P5666

· · 个人记录

[CSP-S2019] 树的重心

所以过了一年多再重新做这道题是什么鬼。。。

刚看的时候会 O(Tn^2) 的暴力,想一想大概有 40pts 。

然后加上链的情况比较好处理,就可以拿到至少 55pts 的好分数。

然后除了这 55pts 我就连口胡都不会了。。。

其实如若只是想通过本题,暴力加上倍增就可以解决。

我们仍然要枚举每一条边,并在每颗子树中寻找重心。不过暴力找的时间复杂度是 O(n) ,如果利用倍增加一些性质寻找就可以优化至 O(\log_2 n)

那么我们可以发现,对于一颗树来说,我们研究根节点,重心在根节点的重子树中;那么在重子树中,它的根节点为总根节点的重儿子,我们研究重儿子,重心一定在重儿子的重子树中,以此类推。结论还算是比较自然。

那么利用这个性质,我们对于每颗子树可以逐层寻找重心,但是为了能够稳定通过此题,可以采用倍增的思想,存储每个节点的 2^i 级的重儿子,就可以在 O(\log_2 n) 下找到重心。

至于针对每一次删边,我们通过回溯的方式修改靠上方的端点与其父节点的关系,并修改相应数据,使这个靠上方端点成为上方子树的根,方便对两个子树的重心寻找工作。

只不过修改其实比较繁复,需要考虑:在删边之后,靠下端点的子树是否是靠上端点的重子树,如若是,则需要用次重子树替换;加入父节点的子树,重子树是否发生变化;等等。不过整个修改的时间复杂度是 O(n\log_2 n) 的。

总时间复杂度 O(n\log_2 n)

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;

const ll N=3e5;

ll T,n,u,v,ans,tot;

ll siz[N+5],hs[N+5][25],fa[N+5],ver[N*2+5],nxt[N*2+5],head[N+5],chs[N+5];

void clr() {
    ans=0;tot=0;
    memset(siz,0,sizeof(siz));memset(fa,0,sizeof(fa));
    memset(hs,0,sizeof(hs));memset(chs,0,sizeof(chs));
    memset(ver,0,sizeof(ver));memset(nxt,0,sizeof(nxt));
    memset(head,0,sizeof(head));
}

ll sol(ll p) {
    ll x=p;
    for(ll k=20;k>=0;k--) {
        if(hs[x][k]==0) continue;
        if(siz[p]-siz[hs[x][k]]<=siz[p]/2) x=hs[x][k];
    }
    ll res=0;
    if(siz[hs[x][0]]<=siz[p]/2&&siz[p]-siz[x]<=siz[p]/2) {
        res+=x;
    }
    if(siz[hs[fa[x]][0]]<=siz[p]/2&&siz[p]-siz[fa[x]]<=siz[p]/2) {
        res+=fa[x];

    }
    return res;
}

void calc(ll p) {
    for(ll i=1;i<=20;i++) {
        hs[p][i]=hs[hs[p][i-1]][i-1];
    }
}

void dfs(ll p,ll fath) {
    siz[p]=1;fa[p]=fath;
    ll tmp_siz=0,tmp_hs=0,tmp_csiz=0,tmp_chs=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dfs(ver[i],p);siz[p]+=siz[ver[i]];
        if(siz[ver[i]]>tmp_siz) {
            tmp_csiz=tmp_siz;tmp_chs=tmp_hs;
            tmp_siz=siz[ver[i]];tmp_hs=ver[i];
        }
        else {
            if(siz[ver[i]]>tmp_csiz) {
                tmp_chs=ver[i];tmp_csiz=siz[ver[i]];
            }
        }
    }
    hs[p][0]=tmp_hs;calc(p);
    chs[p]=tmp_chs;
}

void _dfs(ll p,ll fath) {

    ll tmp1=0,tmp2=0;

    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        ll tmp_fafa=fa[fa[p]];
        ll tmp_siz=siz[p];
        ll tmp_hs=hs[p][0],tmp_chs=chs[p];
        ll tmp_fa=fa[p];
        fa[ver[i]]=0;
        tmp1=sol(ver[i]);fa[ver[i]]=p;ans+=tmp1;
        fa[fa[p]]=p;siz[p]=siz[p]-siz[ver[i]]+siz[fa[p]];
        if(hs[p][0]==ver[i]) hs[p][0]=chs[p];
        if(siz[fa[p]]>siz[hs[p][0]]) hs[p][0]=fa[p];
        calc(p);fa[p]=0;
        tmp2=sol(p);fa[p]=tmp_fa;ans+=tmp2;
        _dfs(ver[i],p);
        fa[fa[p]]=tmp_fafa;
        hs[p][0]=tmp_hs;chs[p]=tmp_chs;siz[p]=tmp_siz;
        calc(p);
    }
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    if(x<0) {x=-x;putchar('-');}
    if(x>9) write(x/10);
    putchar(x%10+48);
}

int main() {

    T=read();

    while(T--) {
        n=read();
        for(ll i=1;i<n;i++) {
            u=read();v=read();add(u,v);add(v,u);
        }
        dfs(1,0);
        _dfs(1,0);
        write(ans);putchar('\n');
        clr();
    }

    return 0;
}