树链剖分

· · 个人记录

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
int n,m,r,p;
struct node{
    int to,nxt;
}e[N];
int w[N];
int h[N];
int tot=0;
void _add(int u,int v)
{
    tot++;
    e[tot].to=v;
    e[tot].nxt=h[u];
    h[u]=tot;
}//建边

int dep[N],siz[N],fa[N],son[N];
void dfs1(int u,int f)//当前节点u,其父节点f
{
    dep[u]=dep[f]+1;//深度
    fa[u]=f;//父节点标记
    siz[u]=1;//子树个数
    int mx=-1;
    for(int i=h[u];i;i=e[i].nxt)
    {
        int v=e[i].to;
        if(v==f)//是父节点就跳
        {
            continue;
        }
        dfs1(v,u);//便利下去
        siz[u]+=siz[v];//子树和
        if(siz[v]>mx)//重链剖分,子树最大的为重边
        {
            mx=siz[v];
            son[u]=v;
        }
    }
    return;
}

int id[N],wt[N],top[N];
int cnt=0;
void dfs2(int u,int tp)
{
    id[u]=++cnt;//给点标号
    wt[cnt]=w[u];//点权转移
    top[u]=tp;//top记录链顶
    if(!son[u])//没有子节点就结束
    {
        return;
    }
    dfs2(son[u],tp);//同一条链最顶端的点不变
    for(int i=h[u];i;i=e[i].nxt)
    {
        int v=e[i].to;
        if(v==fa[u]||v==son[u])//特判
        {
            continue;
        }
        dfs2(v,v);//遍历所有链
    }
    return;
}

int a[N*2],laz[N*2];
int res=0;
void pushdown(int rt,int lenn)//懒标记下放
{
    laz[rt<<1]+=laz[rt];
    laz[rt<<1|1]+=laz[rt];
    a[rt<<1]+=laz[rt]*(lenn-(lenn>>1));
    a[rt<<1|1]+=laz[rt]*(lenn>>1);
    a[rt<<1]%=p;
    a[rt<<1|1]%=p;
    laz[rt]=0;
    return ;
}
void build(int rt,int l,int r)//节点编号rt,区间[l,r]
{
    if(l==r)//单节点
    {
        a[rt]=wt[l];//赋值
        if(a[rt]>p)//取模
        {
            a[rt]%=p;
        }
        return ;
    }
    int mid=(l+r)>>1;//常识,不解释
    build(rt<<1,l,mid);//左
    build(rt<<1|1,mid+1,r);//右
    a[rt]=(a[rt<<1]+a[rt<<1|1])%p;//求和
}
void query(int rt,int l,int r,int lx,int rx)//懒得解释了......
{
    if(lx<=l&&r<=rx)
    {
        res+=a[rt];
        res%=p;
        return ;
    }else{
        if(laz[rt])
        {
            pushdown(rt,r-l+1);
        }
        int mid=(l+r)>>1;
        if(lx<=mid)
        {
            query(rt<<1,l,mid,lx,rx);
        }
        if(rx>mid)
        {
            query(rt<<1|1,mid+1,r,lx,rx);
        }
    }
    return;
}
void update(int rt,int l,int r,int lx,int rx,int k)
{
    if(lx<=l&&r<=rx)//区间内就区间加
    {
        laz[rt]+=k;//懒标记
        a[rt]+=(r-l+1)*k;//区间多少个就要加多少个K
    }else{
        if(laz[rt])//懒标记释放
        {
            pushdown(rt,r-l+1);
        }
        //后面是线段树常规知识,不解释
        int mid=(l+r)>>1;
        if(lx<=mid)
        {
            update(rt<<1,l,mid,lx,rx,k);
        }
        if(rx>mid)
        {
            update(rt<<1|1,mid+1,r,lx,rx,k);
        }
        a[rt]=(a[rt<<1]+a[rt<<1|1])%p;
    }
    return;
}

int qrange(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
        {
            swap(x,y);
        }
        res=0;
        query(1,1,n,id[top[x]],id[x]);
        ans+=res;//求和
        ans%=p;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])
    {
        swap(x,y);
    }
    res=0;
    query(1,1,n,id[x],id[y]);
    ans+=res;
    return ans%p;//取模
}
void updrange(int x,int y,int k)
{
    k%=p;//取模
    while(top[x]!=top[y])//不在同一条链
    {
        if(dep[top[x]]<dep[top[y]])//确保深度上x在y上面
        {
            swap(x,y);
        }
        update(1,1,n,id[top[x]],id[x],k);//能跳完整条链那整条链都要增加,因为他们必然在路径上
        x=fa[top[x]];//上跳
    }
    if(dep[x]>dep[y])
    {
        swap(x,y);//同上
    }
    update(1,1,n,id[x],id[y],k);//不能跳完整条链那将链中跳到的部分增加
    return;
} 

int qrson(int x)
{
    res=0;
    query(1,1,n,id[x],id[x]+siz[x]-1);//树链剖分下来,所有的子树链一定是连续的。
    return res;
}
void updson(int x,int k)
{
    update(1,1,n,id[x],id[x]+siz[x]-1,k);//同上
    return;
}

int main()
{
    scanf("%d%d%d%d",&n,&m,&r,&p);//输入
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&w[i]);//点权
    }
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);//建边
        _add(u,v);
        _add(v,u);
    }
    dfs1(r,0);//dfs序
    dfs2(r,r);//求每条链顶点 
    build(1,1,n);//线段树建树
    while(m--)
    {
        int op,x,y,z;
        scanf("%d",&op);
        if(op==1)//表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
        {
            scanf("%d%d%d",&x,&y,&z);//输入
            updrange(x,y,z);
        }else if(op==2){//表示求树从 x 到 y 结点最短路径上所有节点的值之和。
            scanf("%d%d",&x,&y);
            printf("%d\n",qrange(x,y));
        }else if(op==3){//表示将以 x 为根节点的子树内所有节点值都加上 z。
            scanf("%d%d",&x,&y);
            updson(x,y);
        }else{//表示求以 x 为根节点的子树内所有节点值之和
            scanf("%d",&x);
            printf("%d\n",qrson(x));
        }
    }
    return 0;
}