wbw的笔记——树的直径
wbw_121124 · · 算法·理论
树的直径
树的直径是指树上距离最大的两个结点之间的路径。
注意:
- 树的直径不一定唯一;
- 树的直径的一个端点一定是叶子结点;
- 树的直径若有多条,一定交汇于一点;
- 树上任意结点
x ,距离x 最远的点一定是树的直径的端点。
树的直径的求取方法
算法 1 :2 遍 DFS
- 任取一点
i ,dfs 找到距i 最远的点x ; - 从点
x 出发,dfs 找到距x 最远的点y ; -
时间复杂度
注意:该方法不能适用于边权为负的情形!
代码:
void dfs(int x, int fa, int sum, int& ans)
{
if (sum > dis)
{
ans = x;
dis = sum;
}
for (auto [nxt, w] : nbr[x])
if (nxt != fa)
dfs(nxt, x, sum + w, ans);
return;
}
算法 2 :树形 DP
原理:枚举树树的直径端点的
- 树形
dp 遍历每个点i ; - 维护最长链的长度
dis1[i] 和次长链dis2[i] ; - 枚举每个点
i ,取\max(dis1[i]+dis2[i]) 即为树的直径长度;
时间复杂度
代码:
void dfs(int x, int fa)
{
for (auto [nxt, w] : nbr[x])
if (nxt != fa)
{
dfs(nxt, x);
w = dis1[nxt] + w;
if (w > dis1[x])
{
dis2[x] = dis1[x];
dis1[x] = w;
}
else if (w > dis2[x])
dis2[x] = w;
}
len = max(len, dis1[x] + dis2[x]);
return;
}
例题
U81904 【模板】树的直径
代码:
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 5e5 + 5;
int n, len, dis1[N], dis2[N];
vector<pair<int, int>>nbr[N];
void dfs(int x, int fa)
{
for (auto [nxt, w] : nbr[x])
if (nxt != fa)
{
dfs(nxt, x);
w = dis1[nxt] + w;
if (w > dis1[x])
{
dis2[x] = dis1[x];
dis1[x] = w;
}
else if (w > dis2[x])
dis2[x] = w;
}
len = max(len, dis1[x] + dis2[x]);
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i < n; i++)
{
int x, y, z;
cin >> x >> y >> z;
nbr[x].push_back({ y,z });
nbr[y].push_back({ x,z });
}
dfs(1, 0);
cout << len;
return 0;
}
CF734E Anton and Tree
- 同一个联通块的元素和联通块的另一个元素的边权可以设为
0 ,不同的设为1 ,跑树形dp 。
代码:
#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 2e5 + 5;
int n, dis, a[N];
vector<pair<int, int>>nbr[N];
void dfs(int x, int fa, int sum, int& ans)
{
if (sum >= dis)
{
ans = x;
dis = sum;
}
for (auto [nxt, w] : nbr[x])
if (nxt != fa)
dfs(nxt, x, sum + w, ans);
return;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i < n; i++)
{
int x, y, z;
cin >> x >> y;
if (a[x] == a[y])
z = 0;
else
z = 1;
nbr[x].push_back({ y,z });
nbr[y].push_back({ x,z });
}
int x, y;
dfs(1, 0, 0, x);
dis = 0;
dfs(x, 0, 0, y);
cout << (dis + 1) / 2 << '\n';
return 0;
}