树的直径总结

· · 个人记录

树的直径是指树上每两点间距离的最大值。

方法一:以任意一点作为起点遍历,找到一个与之距离最远的点,该点为直径的其中一个端点,再一这个点为起点遍历,找到一个最远的点,两点间的距离为树的直径大小。时间复杂度O(n)。

树中距离某一直径端点最远的点,至少有一个是该直径的另一个端点。

证明:如果不是,那么我们一定可以用与之距离最远的点更新直径。

方法二:易证直径的两个端点一定是叶子节点,任选一点为根节点建树,求得根节点到每个叶子节点的距离dis[i]。选择两个叶子节点x,y,找到它们的最近公共祖先lca,则两点间的距离为dis[x]+dis[y]-dis[lca]*2,每两点间距离的最大值则为直径。时间复杂度O(n^2)

【例题】P3304

上图为样例,易发现3-2和3-6为树的直径,要找到所有直径都经过的边,可先任选一条直径,从一个端点(l)访问其上的每一个点,若到其他不在这个直径上的点的最大距离等于它到另一端点(r)的距离,则从该点到r的边一定不是公共边。例如图中先选定3-6直径,在访问到4时,4到2的距离等于4到6的距离,因此4到6一定不是公共边。从l到r遍历玩后,再从r到l遍历一遍,最后即可得到答案。

#include<bits/stdc++.h>
#define int long long
#define N 200005
using namespace std;
int n,ans,ans2,maxx,sum,f[N],a[N],fa[N],dep[N],vis[N],dis[N],nex[N],last[N];
struct node{
    int to,v;
};
vector<node>s[N];
int find(int x,int y){
    if(dep[x]>dep[y])swap(x,y);
    while(dep[x]!=dep[y])y=fa[y];
    while(fa[x]!=fa[y]){
        x=fa[x];
        y=fa[y];
    }
    return fa[x];
}
void gets(int x,int y,int lca){
    int now=x;
    while(now!=lca){
        sum++;
        now=fa[now];
    }
    now=y;
    while(now!=lca){
        sum++;
        now=fa[now];
    }
}
void dfs(int x,int bef){
    fa[x]=bef;
    dep[x]=dep[bef]+1;
    for(int i=0;i<s[x].size();i++){
        int to=s[x][i].to;
        if(to==bef)continue;
        f[to]=f[x]+s[x][i].v;
        dfs(to,x);
    }
}
void dfs2(int x,int sta){
    for(int i=0;i<s[x].size();i++){
        int to=s[x][i].to;
        if(f[to]||to==sta)continue;
        f[to]=f[x]+s[x][i].v;
        dfs2(to,sta);
    }
}
void dfs3(int x,int bef){
    for(int i=0;i<s[x].size();i++){
        int to=s[x][i].to;
        if(vis[to]||to==bef)continue;
        dis[to]=dis[x]+s[x][i].v;
        maxx=max(maxx,dis[to]);
        dfs3(to,x);
    }
}
signed main()
{
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v,w;
        scanf("%lld%lld%lld",&u,&v,&w);
        s[u].push_back((node){v,w});
        s[v].push_back((node){u,w});
    }
    dfs(1,0);
    int l,r;
    maxx=0;
    for(int i=1;i<=n;i++){
        if(f[i]>=maxx){
            l=i;
            maxx=f[i];
        }
    }
    memset(f,0,sizeof(f));
    dfs2(l,l);
    int ans=0,lca;
    for(int i=1;i<=n;i++){
        if(f[i]>ans){
            ans=f[i];
            r=i;
        }
    }
    lca=find(l,r);
    gets(l,r,lca);
    int x=l;
    vis[lca]=1;
    while(x!=lca){
        nex[x]=fa[x];
        vis[x]=1;
        x=fa[x];
    }
    x=r;
    while(x!=lca){
        nex[fa[x]]=x;
        vis[x]=1;
        x=fa[x];
    }
    x=l;
    while(nex[x]){
        last[nex[x]]=x;
        x=nex[x];
    }
    maxx=0;
    x=l;
    dfs3(x,x);
    while(maxx!=f[r]-f[x]){
        dis[x]=0;
        maxx=0;
        ans2++;
        x=nex[x];
        dfs3(x,x);
    }
    memset(dis,0,sizeof(dis));
    maxx=0;
    x=r;
    dfs3(x,x);
    while(maxx!=f[x]){
        dis[x]=0;
        maxx=0;
        ans2++;
        x=last[x];
        dfs3(x,x);
    }
    printf("%lld\n%lld",ans,max((int)0,ans2-sum));
    return 0;
}