wbw的笔记——树的直径

· · 算法·理论

树的直径

树的直径是指树上距离最大的两个结点之间的路径。

注意:

  1. 树的直径不一定唯一;
  2. 树的直径的一个端点一定是叶子结点;
  3. 树的直径若有多条,一定交汇于一点;
  4. 树上任意结点 x,距离 x 最远的点一定是树的直径的端点。

树的直径的求取方法

算法 12DFS

  1. 任取一点 idfs 找到距 i 最远的点 x
  2. 从点 x 出发,dfs 找到距 x 最远的点 y

时间复杂度 O(N)

注意:该方法不能适用于边权为负的情形!

代码:

void dfs(int x, int fa, int sum, int& ans)
{
    if (sum > dis)
    {
        ans = x;
        dis = sum;
    }
    for (auto [nxt, w] : nbr[x])
        if (nxt != fa)
            dfs(nxt, x, sum + w, ans);
    return;
}

算法 2:树形 DP

原理:枚举树树的直径端点的 lca,从 lca 的子树中拼接最长的链。

  1. 树形 dp 遍历每个点 i
  2. 维护最长链的长度 dis1[i] 和次长链 dis2[i]
  3. 枚举每个点 i,取 \max(dis1[i]+dis2[i]) 即为树的直径长度;

时间复杂度 O(N)

代码:

void dfs(int x, int fa)
{
    for (auto [nxt, w] : nbr[x])
        if (nxt != fa)
        {
            dfs(nxt, x);
            w = dis1[nxt] + w;
            if (w > dis1[x])
            {
                dis2[x] = dis1[x];
                dis1[x] = w;
            }
            else if (w > dis2[x])
                dis2[x] = w;
        }
    len = max(len, dis1[x] + dis2[x]);
    return;
}

例题

U81904 【模板】树的直径

代码:

#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 5e5 + 5;
int n, len, dis1[N], dis2[N];
vector<pair<int, int>>nbr[N];
void dfs(int x, int fa)
{
    for (auto [nxt, w] : nbr[x])
        if (nxt != fa)
        {
            dfs(nxt, x);
            w = dis1[nxt] + w;
            if (w > dis1[x])
            {
                dis2[x] = dis1[x];
                dis1[x] = w;
            }
            else if (w > dis2[x])
                dis2[x] = w;
        }
    len = max(len, dis1[x] + dis2[x]);
    return;
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    for (int i = 1; i < n; i++)
    {
        int x, y, z;
        cin >> x >> y >> z;
        nbr[x].push_back({ y,z });
        nbr[y].push_back({ x,z });
    }
    dfs(1, 0);
    cout << len;
    return 0;
}

CF734E Anton and Tree

  1. 同一个联通块的元素和联通块的另一个元素的边权可以设为 0,不同的设为 1,跑树形 dp

代码:

#include<bits/stdc++.h>
typedef int int32;
#define int long long
using namespace std;
const int N = 2e5 + 5;
int n, dis, a[N];
vector<pair<int, int>>nbr[N];
void dfs(int x, int fa, int sum, int& ans)
{
    if (sum >= dis)
    {
        ans = x;
        dis = sum;
    }
    for (auto [nxt, w] : nbr[x])
        if (nxt != fa)
            dfs(nxt, x, sum + w, ans);
    return;
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i < n; i++)
    {
        int x, y, z;
        cin >> x >> y;
        if (a[x] == a[y])
            z = 0;
        else
            z = 1;
        nbr[x].push_back({ y,z });
        nbr[y].push_back({ x,z });
    }
    int x, y;
    dfs(1, 0, 0, x);
    dis = 0;
    dfs(x, 0, 0, y);
    cout << (dis + 1) / 2 << '\n';
    return 0;
}