树的直径总结
树的直径是指树上每两点间距离的最大值。
方法一:以任意一点作为起点遍历,找到一个与之距离最远的点,该点为直径的其中一个端点,再一这个点为起点遍历,找到一个最远的点,两点间的距离为树的直径大小。时间复杂度O(n)。
树中距离某一直径端点最远的点,至少有一个是该直径的另一个端点。
证明:如果不是,那么我们一定可以用与之距离最远的点更新直径。
方法二:易证直径的两个端点一定是叶子节点,任选一点为根节点建树,求得根节点到每个叶子节点的距离dis[i]。选择两个叶子节点x,y,找到它们的最近公共祖先lca,则两点间的距离为dis[x]+dis[y]-dis[lca]*2,每两点间距离的最大值则为直径。时间复杂度O(n^2)
【例题】P3304
上图为样例,易发现3-2和3-6为树的直径,要找到所有直径都经过的边,可先任选一条直径,从一个端点(l)访问其上的每一个点,若到其他不在这个直径上的点的最大距离等于它到另一端点(r)的距离,则从该点到r的边一定不是公共边。例如图中先选定3-6直径,在访问到4时,4到2的距离等于4到6的距离,因此4到6一定不是公共边。从l到r遍历玩后,再从r到l遍历一遍,最后即可得到答案。
#include<bits/stdc++.h>
#define int long long
#define N 200005
using namespace std;
int n,ans,ans2,maxx,sum,f[N],a[N],fa[N],dep[N],vis[N],dis[N],nex[N],last[N];
struct node{
int to,v;
};
vector<node>s[N];
int find(int x,int y){
if(dep[x]>dep[y])swap(x,y);
while(dep[x]!=dep[y])y=fa[y];
while(fa[x]!=fa[y]){
x=fa[x];
y=fa[y];
}
return fa[x];
}
void gets(int x,int y,int lca){
int now=x;
while(now!=lca){
sum++;
now=fa[now];
}
now=y;
while(now!=lca){
sum++;
now=fa[now];
}
}
void dfs(int x,int bef){
fa[x]=bef;
dep[x]=dep[bef]+1;
for(int i=0;i<s[x].size();i++){
int to=s[x][i].to;
if(to==bef)continue;
f[to]=f[x]+s[x][i].v;
dfs(to,x);
}
}
void dfs2(int x,int sta){
for(int i=0;i<s[x].size();i++){
int to=s[x][i].to;
if(f[to]||to==sta)continue;
f[to]=f[x]+s[x][i].v;
dfs2(to,sta);
}
}
void dfs3(int x,int bef){
for(int i=0;i<s[x].size();i++){
int to=s[x][i].to;
if(vis[to]||to==bef)continue;
dis[to]=dis[x]+s[x][i].v;
maxx=max(maxx,dis[to]);
dfs3(to,x);
}
}
signed main()
{
cin>>n;
for(int i=1;i<n;i++){
int u,v,w;
scanf("%lld%lld%lld",&u,&v,&w);
s[u].push_back((node){v,w});
s[v].push_back((node){u,w});
}
dfs(1,0);
int l,r;
maxx=0;
for(int i=1;i<=n;i++){
if(f[i]>=maxx){
l=i;
maxx=f[i];
}
}
memset(f,0,sizeof(f));
dfs2(l,l);
int ans=0,lca;
for(int i=1;i<=n;i++){
if(f[i]>ans){
ans=f[i];
r=i;
}
}
lca=find(l,r);
gets(l,r,lca);
int x=l;
vis[lca]=1;
while(x!=lca){
nex[x]=fa[x];
vis[x]=1;
x=fa[x];
}
x=r;
while(x!=lca){
nex[fa[x]]=x;
vis[x]=1;
x=fa[x];
}
x=l;
while(nex[x]){
last[nex[x]]=x;
x=nex[x];
}
maxx=0;
x=l;
dfs3(x,x);
while(maxx!=f[r]-f[x]){
dis[x]=0;
maxx=0;
ans2++;
x=nex[x];
dfs3(x,x);
}
memset(dis,0,sizeof(dis));
maxx=0;
x=r;
dfs3(x,x);
while(maxx!=f[x]){
dis[x]=0;
maxx=0;
ans2++;
x=last[x];
dfs3(x,x);
}
printf("%lld\n%lld",ans,max((int)0,ans2-sum));
return 0;
}