[CSPS 2019] [重心] [树形数据结构] [容斥] 树的重心

· · 个人记录

题意

给定一棵 n 个点的树 T。将该树的任意一条边断开后,可得到两棵新的分裂子树,求分别断开每条边后所得子树的重心编号之和。

思路

在此之前,大家应知道重心的判断式:对于任意 u\in T,都满足 siz_u\le \frac {|T|}2

直观做法:遍历所有边,累加断开后两部分的重心编号。

显然复杂度为 O(n^2) 不可过,那不妨换个角度从点出发,计算有多少条边断开后使得该点为重心,定义 i 的合法断边的数量为 cnt_i,答案即为 \sum cnt_i*i

但是我们还是没有头绪,如何快速计算出 cnt_i 呢?按照 CCF 的出题风格可知,不知道咋做时不妨挖掘性质。

代码实现

所以步骤就是:找重心定根 \to 建新树并预处理 \to 在做一次 \rm DFS 算贡献。

还有血的教训:

  1. 不要忘了初始化。

已经讲得差不多了,就不加注释咯。

#include<bits/stdc++.h>
using namespace std;

#define int long long
const int N=3e5+5;
int t,n,u,v,tot,head[N];
int rt,siz[N],g[N],p[N],ans,ma1,ma2,f[N];
struct qxx{
    int v,nxt;
}e[N<<1];
struct tree{
    int t[N];
    inline void clr() {for(int i=0;i<=N-5;i++){t[i]=0;}}
    inline int lowbit(int a) {return a&-a;}
    inline void upd(int a,int k) {a++;while(a<=n+1){t[a]+=k;a+=lowbit(a);}}
    inline int qry(int a) {a++;int res=0;while(a){res+=t[a];a-=lowbit(a);}return res;}
    inline int get(int l,int r) {return qry(max(r,0LL))-qry(max(l-1,0LL));}
}t1,t2;

inline void add(int u,int v){
    e[++tot]={v,head[u]};
    head[u]=tot;
}
inline void find_rt(int u,int fa){
    bool flag=1;
    siz[u]=1;
    for(int i=head[u]; i;i=e[i].nxt){
        if(e[i].v^fa){
            find_rt(e[i].v,u);
            siz[u]+=siz[e[i].v];
            if(siz[e[i].v]>n/2) flag=0;
        }
    }
    if(n-siz[u]<=n/2&&flag) rt=u;
} 
inline void dfs(int u,int fa){
    siz[u]=1; g[u]=0;
    for(int i=head[u]; i;i=e[i].nxt){
        if(e[i].v^fa){
            dfs(e[i].v,u);
            g[u]=max(g[u],siz[e[i].v]);
            siz[u]+=siz[e[i].v];
        }
    }
} 
inline void find_fir_sec(){
    ma1=-1; ma2=-1;
    for(int i=head[rt]; i;i=e[i].nxt){
        if(siz[e[i].v]>=siz[ma1]){
            ma2=ma1; 
            ma1=e[i].v; 
        } 
        else{ 
            if(siz[e[i].v]>=siz[ma2]){ 
                ma2=e[i].v;
            } 
        } 
    } 
} 
inline void dfs2(int u,int fa){ 
    t1.upd(siz[u],-1); t1.upd(n-siz[u],1); t2.upd(siz[u],1); 
    if(fa==ma1||f[fa]==1){
        f[u]=1;
    }
    if(u^rt){ 
        ans+=u*t1.get(n-siz[u]*2,n-g[u]*2); 
        ans+=u*t2.get(n-siz[u]*2,n-g[u]*2); 
        if(f[u]==1){
            if(2*max(siz[ma1]-siz[u],siz[ma2])<=n-siz[u]) ans+=rt;
        }
        else{
            if(u^ma1&&2*siz[ma1]<=n-siz[u]) ans+=rt;
            if(u==ma1&&2*siz[ma2]<=n-siz[ma1]) ans+=rt;
        }
    } 
    for(int i=head[u]; i;i=e[i].nxt){ 
        if(e[i].v^fa){ 
            dfs2(e[i].v,u); 
        }   
    }
    if(u^rt){
        ans-=u*t2.get(n-siz[u]*2,n-g[u]*2);
    }
    t1.upd(siz[u],1); t1.upd(n-siz[u],-1);
}
signed main(){
    cin>>t;
    while(t--){
        memset(head,0,sizeof head);
        memset(f,0,sizeof f);
        rt=0; ans=0; tot=0;
        t1.clr(); t2.clr();

        cin>>n;
        for(int i=1; i<n;i++){
            scanf("%lld%lld",&u,&v);
            add(u,v); add(v,u);
        }
        find_rt(1,0);
        dfs(rt,0);
        find_fir_sec();

        for(int i=1; i<=n;i++) t1.upd(siz[i],1);
        dfs2(rt,0);
        printf("%lld\n",ans);
    }
    return 0;
}