[ZJOI2015]幻想乡战略游戏

· · 个人记录

主要问题是,每次修改之后,如何快速地找到新的重心。

我们建出点分树。每次从点分树的根开始遍历:

假设当前在 p 节点,那么遍历 p原树上的儿子 y,通过点分树暴力计算 y 作为重心时的答案(具体见下文)。找到所有儿子答案的最小值 mi。当 mip 作为根时的答案更优时(显然这时候只有一个儿子的答案比 p 优),那么就跳到这个儿子在点分树上的根上,继续执行此操作,直到 mi 不必当前答案优。

这个做法可以类比在序列上二分,序列上目前的二分范围 [l,r] 好比当前所在点分树的子树,mid 好比重心。其本质都是利用“单调性”,每次来“找中间点”来将复杂度降到 \log 级别。

复杂度:点分树深度 O(\log n),计算一个点的答案 O(\log n),点的度数 O(20),总复杂度 O(q \times \log^2 n \times 20)

代码(\text{bl} 表示原树上 xy 的边所找到的点分树上的根 rt,具体实现看总代码):

int query(int x)
{
    int p=x;
    st[0]=0;
    while(1)
    {
        bj[p]=1,st[++st[0]]=p;
        ll mi=1e18;
        int mix=0;
        for(int i=h[p];i;i=e[i].Next)
        {
            int y=e[i].to;
            if(bj[y]) continue;
            ll v=calc(y);
            if(v<mi) mi=v,mix=bl[e[i].id];
        }
        if(mi<calc(p)) p=mix;
        else break;
    }
    for(int i=1;i<=st[0];i++) bj[st[i]]=0;
    return p;
}

接下来考虑如何计算所有点到定点 x 的距离乘上点权之和。

类似 [HNOI2015]开店 的思路,我们建出点分树,然后从他开始往上跳,当跳到一个点 p,记 tot_p 点分树上以 p 为根的子树内的点权和,\text{s1}_i 表示所有点到 p 的点权乘距离和,\text{s2}_i 表示 p 子树内所有点到 p 父亲的点权乘距离和。那么我们将答案加上 \text{s1}_p + tot_p \times \text{dist}(p,x),再找到 p 的一个儿子 q,满足 xq 的子树内,将答案减去 \text{s2}_q + tot_q \times \text{dist}(p,x),即可完成计算答案。代码可参考:

ll calc(int x)
{
    int p=x,q=0;
    ll res=0;
    while(p)
    {
        res+=s1[p]+tot[p]*getdis(x,p);
        if(q) res-=s2[q]+tot[q]*getdis(x,p);
        q=p,p=fa[p];
    }
    return res;
}

那么每次修改就可以在点分树上往上跳,更新 tot,s1,s2 的值:

void change(int x,int k)
{
    int p=x,q=0;
    while(p)
    {
        tot[p]+=k;
        s1[p]+=getdis(x,p)*k;
        if(q) s2[q]+=getdis(x,p)*k;
        q=p,p=fa[p];
    }
}

总代码:

#include <bits/stdc++.h>
#ifdef LOCAL
#include "txm/debug.h"
#endif
using namespace std;
typedef long long ll;
const int N=100005,INF=0x3f3f3f3f;
struct edge {int to,Next,v,id;}e[N<<1];
int n,q,a[N],h[N],cnt,d[N],sz[N];
int top[N],prt[N],son[N],dep[N],fa[N],bl[N];
int mi,rt,siz,st[N];
ll tot[N],sum[N],s1[N],s2[N];
bool v[N],bj[N];
void Addedge(int x,int y,int z,int id) {e[++cnt]=(edge){y,h[x],z,id},h[x]=cnt;}
void dfs1(int x)
{
    sz[x]=1,dep[x]=dep[prt[x]]+1;
    for(int i=h[x];i;i=e[i].Next)
    {
        int y=e[i].to;
        if(y==prt[x]) continue;
        prt[y]=x;
        sum[y]=sum[x]+e[i].v;
        dfs1(y),sz[x]+=sz[y];
        if(sz[y]>sz[son[x]]) son[x]=y;
    }
}
void dfs2(int x,int tp)
{
    top[x]=tp;
    if(!son[x]) return;
    dfs2(son[x],tp);
    for(int i=h[x];i;i=e[i].Next)
    {
        int y=e[i].to;
        if(y!=prt[x]&&y!=son[x]) dfs2(y,y);
    }
}
int lca(int x,int y)
{
    while(top[x]!=top[y]) dep[top[x]]>dep[top[y]]?(x=prt[top[x]]):(y=prt[top[y]]);
    return dep[x]<dep[y]?x:y;
}
ll getdis(int x,int y) {return sum[x]+sum[y]-2*sum[lca(x,y)];}
void findrt(int x,int p)
{
    sz[x]=1;
    int mx=0;
    for(int i=h[x];i;i=e[i].Next)
    {
        int y=e[i].to;
        if(y==p||v[y]) continue;
        findrt(y,x);
        sz[x]+=sz[y];
        mx=max(mx,sz[y]);
    }
    mx=max(mx,siz-sz[x]);
    if(mx<mi) mi=mx,rt=x;
}
void getsize(int x,int p)
{
    siz++;
    for(int i=h[x];i;i=e[i].Next)
        if(e[i].to!=p&&!v[e[i].to]) getsize(e[i].to,x);
}
void build(int x)
{
    v[x]=1;
    for(int i=h[x];i;i=e[i].Next)
    {
        int y=e[i].to;
        if(v[y]) continue;
        siz=0,getsize(y,x);
        mi=INF,findrt(y,x);
        bl[e[i].id]=rt;
        fa[rt]=x,build(rt);
    }
}
void change(int x,int k)
{
    int p=x,q=0;
    while(p)
    {
        tot[p]+=k;
        s1[p]+=getdis(x,p)*k;
        if(q) s2[q]+=getdis(x,p)*k;
        q=p,p=fa[p];
    }
}
ll calc(int x)
{
    int p=x,q=0;
    ll res=0;
    while(p)
    {
        res+=s1[p]+tot[p]*getdis(x,p);
        if(q) res-=s2[q]+tot[q]*getdis(x,p);
        q=p,p=fa[p];
    }
    return res;
}
int query(int x)
{
    int p=x;
    st[0]=0;
    while(1)
    {
        bj[p]=1,st[++st[0]]=p;
        ll mi=1e18;
        int mix=0;
        for(int i=h[p];i;i=e[i].Next)
        {
            int y=e[i].to;
            if(bj[y]) continue;
            ll v=calc(y);
            if(v<mi) mi=v,mix=bl[e[i].id];
        }
        if(mi<calc(p)) p=mix;
        else break;
    }
    for(int i=1;i<=st[0];i++) bj[st[i]]=0;
    return p;
}
signed main()
{
    scanf("%d %d",&n,&q);
    for(int i=1,x,y,z;i<n;i++)
    {
        scanf("%d %d %d",&x,&y,&z);
        Addedge(x,y,z,i);
        Addedge(y,x,z,i);
    }
    dfs1(1),dfs2(1,1);
    siz=n,mi=INF,findrt(1,0);
    int RT=rt;
    build(rt);
    int x,y,lst=1;
    while(q--)
    {
        scanf("%d %d",&x,&y);
        change(x,y);
        printf("%lld\n",calc(query(RT)));
    }
    return 0;
}