树的直径与重心

· · 算法·理论

一.树的直径

定义: 树上任意两节点之间最长的 简单路径即为树的「直径」。

显然,一棵树可以有多条直径,他们的长度相等。

可以用两次 DFS 或者树形 DP 的方法在 O(n) 时间求出树的直径。

  1. 两次 DFS

过程: 首先从任意节点 y 开始进行第一次 DFS,到达距离其最远的节点,记为 z ,然后再从 z 开始做第二次 DFS,到达距离 z 最远的节点,记为 z' ,则 (z,z') 即为树的直径。

)

void dfs(int u, int fa) {
  for (int v : E[u]) {
    if (v == fa) continue;
    d[v] = d[u] + 1;
    if (d[v] > d[c]) c = v;
    dfs(v, u);
  }
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    E[u].push_back(v), E[v].push_back(u);
  }
  dfs(1, 0);
  d[c] = 0, dfs(c, 0);
  printf("%d\n", d[c]);
  return 0;
}

2.树形 DP

这里提供一种只使用一个数组进行的树形 DP 方法。

我们定义 dp[u] :以 u 为根的子树中,从 u 出发的最长路径。那么容易得出转移方程: dp[u] = \max(dp[u], dp[v] + w(u, v)) ,其中的 v 为 u 的子节点, w(u, v) 表示所经过边的权重。

对于树的直径,实际上是可以通过枚举从某个节点出发不同的两条路径相加的最大值求出。因此,在 DP 求解的过程中,我们只需要在更新 dp[u] 之前,计算 d = \max(d, dp[u] + dp[v] + w(u, v)) 即可算出直径 d

void dfs(int u, int fa) {
  for (int v : E[u]) {
    if (v == fa) continue;
    dfs(v, u);
    d = max(d, dp[u] + dp[v] + 1);
    dp[u] = max(dp[u], dp[v] + 1);
  }
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    E[u].push_back(v), E[v].push_back(u);
  }
  dfs(1, 0);
  printf("%d\n", d);
  return 0;
}

板子:B4016

考试题CF1404B Tree Tag

我们考虑什么情况下Alice可以获胜.

3.开始时 A 和 B 间的距离小于 da ,这样 A 第一步就能追上 B。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm> 
#include<vector>
#define int long long
using namespace std;
const int MAXN=1e5+10;
int n,a,b,da,db,dis[MAXN];
vector<int>e[MAXN];
void dfs(int u,int fa){
    dis[u]=dis[fa]+1;
    for(int i=0;i<e[u].size();i++){
        int v=e[u][i];
        if(v!=fa)dfs(v,u);
    }
}
void solve(){
    cin>>n>>a>>b>>da>>db;
    for(int i=1;i<=n;i++)e[i].clear();
    for(int i=1;i<n;i++){
        int u,v; 
        cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1,0); 
    int maxn=0;
    for(int i=1;i<=n;i++) 
        if(dis[maxn]<dis[i])maxn=i;
    dfs(maxn,0);
    maxn=0;
    for(int i=1;i<=n;i++) 
        if(dis[maxn]<dis[i])maxn=i;
    if(dis[maxn]<=da*2+1){
        printf("Alice\n");
        return;
    }
    dfs(a,0);
    if(dis[b]<=da+1) {
        printf("Alice\n");
        return;
    }
    if(2*da>=db){
        printf("Alice\n");
        return;
    }
    printf("Bob\n");
}
signed main(){
    int T;
    cin>>T;
    while(T--)solve();
    return 0;
}

二.树的重心

1.树的重心

树的重心是指树上一点,去掉后最大子树可以取得最小值的点。这样定义可能比较抽象,我们来看一个例子——

【无根树】去掉1后,子树最大为4;去掉2后,子树最大为5;去掉3后,没影响;去掉4后,子树最大为7;去掉5后,没影响;去掉6后,子树最大为5;去掉7后,子树最大为6;去掉8、9后,均无影响。

所以这棵树的重心为点1。由此也可以得出重心的另一个定义——去掉该点后最大子树大小不超过 n/2 。这一定义在应用方面比较多。

求法

在 DFS 中计算每个子树的大小,记录「向下」的子树的最大大小,利用总点数 - 当前子树(这里的子树指有根树的子树)的大小得到「向上」的子树的大小,然后就可以依据定义找到重心了。

// 这份代码默认节点编号从 1 开始,即 i ∈ [1,n]
int size[MAXN],  // 这个节点的「大小」(所有子树上节点数 + 该节点)
    weight[MAXN],  // 这个节点的「重量」,即所有子树「大小」的最大值
    centroid[2];  // 用于记录树的重心(存的是节点编号)

void GetCentroid(int cur, int fa) {  // cur 表示当前节点 (current)
  size[cur] = 1;
  weight[cur] = 0;
  for (int i = head[cur]; i != -1; i = e[i].nxt) {
    if (e[i].to != fa) {  // e[i].to 表示这条有向边所通向的节点。
      GetCentroid(e[i].to, cur);
      size[cur] += size[e[i].to];
      weight[cur] = max(weight[cur], size[e[i].to]);
    }
  }
  weight[cur] = max(weight[cur], n - size[cur]);
  if (weight[cur] <= n / 2) {  // 依照树的重心的定义统计
    centroid[centroid[0] != 0] = cur;
  }
}

考试题:CF1406C Link Cut Centroids 分类讨论.

首先当树只有一个重心的时候,我们删掉任意边再加上原边即可.

再看有两个重心的情况.

显然这棵树必定是类似这样的:

即删掉 A 后,以 B 为根的子树是剩下的最大连通块,反之亦然.

那就可以得到一个结论:

删掉边 (A,B) 后,两棵树的大小相等. 那我们只要使两棵树的大小不相等,且不使新的点成为重心即可.

那就考虑直接从 A 树中随机抽取一位幸运叶子节点,把这个节点与它父亲的边断开,连到 B 的直接儿子里去.

这样, A 树的大小变小了,而 B 树的大小变大了,且不会有新的节点成为重心.

A 就不再是重心了,而 B 则成为了唯一的重心.

maxn[u]‌:存储以节点u为根的子树中,最大连通块的大小。

minn‌:记录所有节点中maxn[u]的最小值,即树的重心对应的最大连通块大小

#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
const int MAXN=1000010;
struct Node{
    int v,nxt;
}e[MAXN*2];
int h[MAXN],cnt=0;
void add(int u,int v){
    e[++cnt].v=v;
    e[cnt].nxt=h[u];
    h[u]=cnt;
}
int n,siz[MAXN],maxn[MAXN],minn,a[MAXN],cntt,t;
void dfs(int u,int fa){
    siz[u]=maxn[u]=1;
    for(int i=h[u];i;i=e[i].nxt){
        int v=e[i].v;
        if(v==fa)continue;
        dfs(v,u);
        siz[u]+=siz[v];
        maxn[u]=max(maxn[u],siz[v]);
    }
    maxn[u]=max(maxn[u],n-siz[u]);
    minn=min(maxn[u],minn);
}
int find(int u,int fa){
    for(int i=h[u];i;i=e[i].nxt){
        int v=e[i].v;
        if(v==fa)continue;
        t=u;
        return find(v,u);
    }
    return u;
}
signed main(){
    int T;
    cin>>T;
    while(T--){
        cnt=cntt=0;
        cin>>n;
        minn=n+1;
        for(int i=1;i<=n;i++)h[i]=0;
        for(int i=1;i<n;i++){
            int u,v;
            cin>>u>>v;
            add(u,v);
            add(v,u);
        }
        dfs(1,0);
        for(int i=1;i<=n;i++)
            if(maxn[i]==minn)
                a[++cntt]=i;
        if(cntt==1)
            cout<<a[1]<<" "<<e[h[a[1]]].v<<"\n"<<a[1]<<" "<<e[h[a[1]]].v;

        else{
            int tmp=find(a[2],a[1]);
            cout<<t<<" "<<tmp<<"\n"<<a[1]<<" "<<tmp;
        }
        cout<<'\n';
    }
    return 0;
}