P3384

· · 个人记录

【模板】轻重链剖分/树链剖分

首先,我们对整棵树进行 DFS,得到了每个节点对于根节点的深度、父亲、子树大小以及重儿子。

然后,我们再对整棵树进行 DFS,并优先搜索重儿子,从而得到重链、每个节点所在重链的起始节点、以及每个节点在线段树上映射的下标(即 DFS 序)。搜索的原则就是优先重儿子,轻儿子一定为某条重链的起始节点。

接着,我们就可以建立一颗线段树以维护这些信息。

对于操作 1

我们对于每条路径上的修改操作,可以知道每条路径都被分割成了大大小小的重链。假设路径两个端点分别为 xy,不妨设 dt_x>dt_y,那我们就需要向上调整 x。向上调整的方法是:将 x 转移至 其所在链的起始节点 的父节点。如果中途出现深度大小关系发生变化,就交换并重复执行上述操作,并在跳链的过程中修改线段树,直至 xy 处在同一条重链中。此时再修改 xy 之间在线段树上的值,我们就完成了整个修改操作。

感性理解这个路径修改的复杂度。我们把本来很长的路径转移转变为了跳链,跳链固然可以加快速度,但是这个速度到底能加到多快?答案是 O(\log n)。感性证明的话,就是重儿子本身的性质决定了这个跳链的过程中:如果走了轻边,子树的相对大小至少减少一半,决定了经过的轻边至多只有 \log n 条;如果走了重链,显然重链之间是由轻链连接的,所以经过重链的条数亦不会超过 \log n 条。

然后再加上过程中线段树维护的复杂度,总的单次修改路径复杂度为 O(\log^2n)

对于操作 2

求路径权值和与操作 1 是类似的,我们同样需要通过跳链来遍历需要合并的信息,然后在线段树上查找即可。

时间复杂度 O(\log^2 n)

对于操作 3

让子树节点的权值全部加上某个值。实际上我们发现,虽然得到的 DFS 序是优先重儿子的,但不影响对于节点 p,其在 DFS 序中的后 siz_p-1 个节点覆盖了它的子树。同时我们也能知道,对于节点 p,其子树的所有节点在线段树上覆盖了连续的区间,于是我们只要对连续的区间进行加法即可。

时间复杂度 O(\log n)

对于操作 4

同操作 3,直接在线段树上查询区间和即可。

时间复杂度 O(\log n)

至此,本题的四种不同操作得到解决,总体时间复杂度 O(n\log n+m\log^2 n)

代码:

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=1e5;

struct sgt{
    ll l,r,dat,laz;
    #define l(x) tree[x].l
    #define r(x) tree[x].r
    #define dat(x) tree[x].dat
    #define laz(x) tree[x].laz
}tree[N*4+5];

ll n,m,root,mo,tot,cnt,u,v,op,x,y,z;

ll head[N+5],ver[N*2+5],nxt[N*2+5];

ll hs[N+5],id[N+5],fa[N+5],dt[N+5],siz[N+5],top[N+5];

ll a[N+5],wt[N+5];

void build(ll p,ll l,ll r) {
    l(p)=l;r(p)=r;
    if(l==r) {
        dat(p)=wt[l]%mo;return;
    }
    ll mid=(l+r)>>1;
    build(p*2,l,mid);build(p*2+1,mid+1,r);
    dat(p)=(dat(p*2)+dat(p*2+1))%mo;
}

void spread(ll p) {
    if(laz(p)) {
        dat(p*2)=(dat(p*2)+laz(p)*(r(p*2)-l(p*2)+1));
        laz(p*2)=(laz(p*2)+laz(p))%mo;
        dat(p*2+1)=(dat(p*2+1)+laz(p)*(r(p*2+1)-l(p*2+1)+1));
        laz(p*2+1)=(laz(p*2+1)+laz(p))%mo;
        laz(p)=0;
    }
}

ll query(ll p,ll l,ll r) {
    if(l<=l(p)&&r>=r(p)) return dat(p)%mo;
    spread(p);
    ll mid=(l(p)+r(p))>>1,res=0;
    if(l<=mid) res+=query(p*2,l,r);
    if(r>mid) res+=query(p*2+1,l,r);
    return res%mo;
}

void chg(ll p,ll l,ll r,ll k) {
    if(l<=l(p)&&r>=r(p)) {
        laz(p)=(laz(p)+k)%mo;
        dat(p)=(dat(p)+k*(r(p)-l(p)+1)%mo)%mo;
        return;
    }
    spread(p);
    ll mid=(l(p)+r(p))>>1;
    if(l<=mid) chg(p*2,l,r,k);
    if(r>mid) chg(p*2+1,l,r,k);
    dat(p)=(dat(p*2)+dat(p*2+1))%mo;
}

void dfs(ll p,ll fath) {
    dt[p]=dt[fath]+1;fa[p]=fath;siz[p]=1;
    ll ma=0;
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fath) continue;
        dfs(ver[i],p);siz[p]+=siz[ver[i]];
        if(siz[ver[i]]>ma) {
            hs[p]=ver[i];ma=siz[ver[i]];
        }
    }
}

void dfs_(ll p,ll topf) {
    id[p]=++cnt;wt[cnt]=a[p];top[p]=topf;
    if(!hs[p]) return;
    dfs_(hs[p],topf);
    for(ll i=head[p];i;i=nxt[i]) {
        if(ver[i]==fa[p]||ver[i]==hs[p]) continue;
        dfs_(ver[i],ver[i]);
    }
}

void chgpath(ll x,ll y,ll k) {
    k%=mo;
    while(top[x]!=top[y]) {
        if(dt[top[x]]<dt[top[y]]) swap(x,y);
        chg(1,id[top[x]],id[x],k);
        x=fa[top[x]];
    }
    if(dt[x]>dt[y]) swap(x,y);
    chg(1,id[x],id[y],k);
}

ll qpath(ll x,ll y) {
    ll ans=0;
    while(top[x]!=top[y]) {
        if(dt[top[x]]<dt[top[y]]) swap(x,y);
        ans=(ans+query(1,id[top[x]],id[x]))%mo;
        x=fa[top[x]];
    }
    if(dt[x]>dt[y]) swap(x,y);
    ans=(ans+query(1,id[x],id[y]))%mo;
    return ans;
}

void chgsub(ll p,ll k) {
    chg(1,id[p],id[p]+siz[p]-1,k);
}

ll qsub(ll p) {
    return query(1,id[p],id[p]+siz[p]-1);
}

void add(ll u,ll v) {
    ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    if(x<0) {x=-x;putchar('-');}
    ll y=10,len=1;
    while(y<=x) {y*=10;len++;}
    while(len--) {y/=10;putchar(x/y+48);x%=y;}
}

int main() {

    n=read();m=read();root=read();mo=read();

    for(ll i=1;i<=n;i++) a[i]=read();

    for(ll i=1;i<n;i++) {
        u=read();v=read();
        add(u,v);add(v,u);
    }

    dfs(root,0);dfs_(root,root);

    build(1,1,n);

    while(m--) {
        op=read();
        if(op==1) {
            x=read();y=read();z=read();
            chgpath(x,y,z);
        }
        if(op==2) {
            x=read();y=read();
            write(qpath(x,y));putchar('\n');
        }
        if(op==3) {
            x=read();z=read();
            chgsub(x,z);
        }
        if(op==4) {
            x=read();
            write(qsub(x));putchar('\n');
        }
    }

    return 0;
}