树剖

· · 题解

刚刚学树剖的时候,因为有一些图论的知识,所以理论知识很快就明白了,但代码码了很久,蒟蒻的代码能力还是太垃圾了。

废话就不多说,先在我就谈谈树剖;

先上一道题吧(洛谷P3384)(大佬勿喷) 看到这道题大家有什么想法呢?暴力?(n<=10^5 )(一看就知道用 logn 的算法)a不了吧,废话不多说,正式进入正题。 什么是树剖,树剖就是它的字面意思,将树剖成很多部分然后用线段树维护。

在我们来谈如何进行具体操作之前,我们要来学习一下一些与树剖有关的基础树的知识。

重儿子:num[u]为v的子节点中num值最大的,那么u就是v的重儿子。 轻儿子:v的其它子节点。 重边:点v与其重儿子的连边。 轻边:点v与其轻儿子的连边。 重链:由重边连成的路径。

轻链:轻边。

(本图摘自洛谷)

图中黑色的边是重边,红色的点是重儿子;

剖分后的树有如下性质: 性质1:如果(v,u)为轻边,则siz[u] * 2 < siz[v];

性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。

大概了解了些基本知识,我们就来谈谈那道题到底怎么求解吧;

首先,我们要进行两次dfs,第一次dfs将他的fa、deep、size、son//求的便是英文字母的字面意思

话不多说,上代码;

void dfs1(long long x)//当前所在节点,开始及是根
{
    si[x]=1;//将节点的子树大小初始化为1,及自己
    for(long long i=0;i<q[x].size();++i)//遍历自己的儿子
    {
        if(q[x][i]==fa[x])//不为爹
            continue;
        dep[q[x][i]]=dep[x]+1;//儿子深度为爸爸深度加1
        fa[q[x][i]]=x;//把儿子的爸爸找到
        dfs1(q[x][i]);//继续dfs儿子
        si[x]+=si[q[x][i]];//递归回来x的树大小便是儿子的树大小相加再加上自己的大小(
自己的大小便是1啊)}}

dfs1完了,就进行dfs2,dfs2有什么用呢?

即将树按照重链重新编个号(我代码定义的pos),还要与此同时找到每一个点的链顶(我代码定义的top)

代码如下:

void dfs2(long long x,long long grfa)//x为当前节点,grfa及当前节点的链顶
{
    long long k=0;//判断重链的计数器
    sz++;//当前加到的节号
    pos[x]=sz;
    top[x]=grfa;
    for(long long i=0;i<q[x].size();++i)//遍历儿子
    {
        if(dep[x]<dep[q[x][i]] and si[q[x][i]]>si[k])//找重儿子
        {
            k=q[x][i];
        }
    }
    if(k==0)//没有儿子
        return;
    dfs2(k,grfa);//先递归把重链找到
    for(long long i=0;i<q[x].size();++i)//遍历不是重儿子的其它儿子
    {
        if(dep[x]<dep[q[x][i]] and k!=q[x][i])
            dfs2(q[x][i],q[x][i]);
    }
}

我们将预处理搞好了,接下来便是用线段树来维这一棵树(以下是线段树的模板,我就不详细讲了):

struct node{long long l,r,ls,rs,c,f;}tr[maxn];
//l,r为区间,ls左节点,rs右节点,c是data,f是lazy标记
long long nw(long long l,long long r)//蒟蒻的laji线段树//新开区间为l到r的节点
{
    tr[++t]=(node){l,r,0,0,0,0};
    return t;//返回节点号
}
void xf(long long o)//下放lazy标记
{
    long long l=tr[o].ls,r=tr[o].rs;
    tr[o].c+=(tr[o].r-tr[o].l+1)*tr[o].f%mo;
    tr[l].f+=tr[o].f;
    tr[r].f+=tr[o].f;
    tr[o].f=0;
}
void gx(long long o){//将节点data递归会来
    long long l=tr[o].ls,r=tr[o].rs;
    xf(l),xf(r);
    tr[o].c=(tr[l].c+tr[r].c)%mo;
}
void xg(long long o,long long l,long long r,long long c)//将l到r区间加上c
{
    xf(o);
    if(tr[o].l==l&&tr[o].r==r)
    {
        tr[o].f+=c;
        return;
    }
    long long mid=tr[o].l+tr[o].r>>1;
    if(r<=mid)
        xg(tr[o].ls?tr[o].ls:tr[o].ls=nw(tr[o].l,mid),l,r,c);
    else
    if(l>mid)
        xg(tr[o].rs?tr[o].rs:tr[o].rs=nw(mid+1,tr[o].r),l,r,c);
    else
        xg(tr[o].ls,l,mid,c),xg(tr[o].rs,mid+1,r,c);
    gx(o);
}
long long cx(long long o,long long l,long long r)//求l到r的值
{
    xf(o);
    if(tr[o].l==l&&tr[o].r==r)
        return tr[o].c%mo;
    long long mid=tr[o].l+tr[o].r>>1;
    if(r<=mid)
        return cx(tr[o].ls,l,r);
    else
    if(l>mid)
        return cx(tr[o].rs,l,r);
    else
        return (cx(tr[o].ls,l,mid)+cx(tr[o].rs,mid+1,r))%mo;
}

下面代码便是如何将树和线段树结合其来:

long long dianhe(long long x,long long y)//求两点间的距离
{
    long long ans=0;
    while(top[x]!=top[y])//如果两点不在一条链上
    {
        if(dep[top[x]]<dep[top[y]])//保证x的链顶比y的链顶深
        swap(x,y);
        ans+=cx(1,pos[top[x]],pos[x]);//直接将x到x的链顶的距离用线段树求和
        x=fa[top[x]];//将x跳到链顶的爸爸
    }
    if(dep[x]<dep[y])//当到一条链上,保证x的深度>y
        swap(x,y);
    ans+=cx(1,pos[y],pos[x]);//直接用线段树查询x到y的距离
    return ans%mo;
}
void dianchage(long long x,long long y,long long z)//将x和y两点之间加上z
{
    while(top[x]!=top[y])//就不多说了,和上面一样,将cx改为修改及可
    {
        if(dep[top[x]]<dep[top[y]])
        swap(x,y);
        xg(1,pos[top[x]],pos[x],z);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y])
        swap(x,y);
    xg(1,pos[y],pos[x],z);
    return;
}
long long shuhe(long long x)//求树的和
{
    return cx(1,pos[x],pos[x]+si[x]-1);//因为在树上树儿子编号是连续的,直接区间求和及可
}
void shuchage(long long x,long long y)//将子树每个点加上y
{
    xg(1,pos[x],pos[x]+si[x]-1,y);//同上
    return;
}

如果将上面的看明白了,树剖就差不多入门了(蒟蒻也不可能给你们讲什么高级的东西)

下面我就直接上完整代码吧

#include<bits/stdc++.h>
using namespace std;
const long long maxn=400100;
struct node
{
    long long l,r,ls,rs,c,f;
}tr[maxn];
long long n,m,root,mo,sz;
long long a[maxn],si[maxn],dep[maxn],fa[maxn],pos[maxn],top[maxn];
vector<long long>q[maxn];
long long t;
long long nw(long long l,long long r)
{
    tr[++t]=(node){l,r,0,0,0,0};
    return t;
}
void xf(long long o)
{
    long long l=tr[o].ls,r=tr[o].rs;
    tr[o].c+=(tr[o].r-tr[o].l+1)*tr[o].f%mo;
    tr[l].f+=tr[o].f;
    tr[r].f+=tr[o].f;
    tr[o].f=0;
}
void gx(long long o){
    long long l=tr[o].ls,r=tr[o].rs;
    xf(l),xf(r);
    tr[o].c=(tr[l].c+tr[r].c)%mo;
}
void xg(long long o,long long l,long long r,long long c)
{
    xf(o);
    if(tr[o].l==l&&tr[o].r==r)
    {
        tr[o].f+=c;
        return;
    }
    long long mid=tr[o].l+tr[o].r>>1;
    if(r<=mid)
        xg(tr[o].ls?tr[o].ls:tr[o].ls=nw(tr[o].l,mid),l,r,c);
    else
    if(l>mid)
        xg(tr[o].rs?tr[o].rs:tr[o].rs=nw(mid+1,tr[o].r),l,r,c);
    else
        xg(tr[o].ls,l,mid,c),xg(tr[o].rs,mid+1,r,c);
    gx(o);
}
long long cx(long long o,long long l,long long r)
{
    xf(o);
    if(tr[o].l==l&&tr[o].r==r)
        return tr[o].c%mo;
    long long mid=tr[o].l+tr[o].r>>1;
    if(r<=mid)
        return cx(tr[o].ls,l,r);
    else
    if(l>mid)
        return cx(tr[o].rs,l,r);
    else
        return (cx(tr[o].ls,l,mid)+cx(tr[o].rs,mid+1,r))%mo;
}
void dfs1(long long x)
{
    si[x]=1;
    for(long long i=0;i<q[x].size();++i)
    {
        if(q[x][i]==fa[x])
            continue;
        dep[q[x][i]]=dep[x]+1;
        fa[q[x][i]]=x;
        dfs1(q[x][i]);
        si[x]+=si[q[x][i]];
    }
}
void dfs2(long long x,long long grfa)
{
    long long k=0;
    sz++;
    pos[x]=sz;
    top[x]=grfa;
    for(long long i=0;i<q[x].size();++i)
    {
        if(dep[x]<dep[q[x][i]] and si[q[x][i]]>si[k])
        {
            k=q[x][i];
        }
    }
    if(k==0)
        return;
    dfs2(k,grfa);
    for(long long i=0;i<q[x].size();++i)
    {
        if(dep[x]<dep[q[x][i]] and k!=q[x][i])
            dfs2(q[x][i],q[x][i]);
    }
}
long long dianhe(long long x,long long y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
        swap(x,y);
        ans+=cx(1,pos[top[x]],pos[x]);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y])
        swap(x,y);
    ans+=cx(1,pos[y],pos[x]);
    return ans%mo;
}
void dianchage(long long x,long long y,long long z)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
        swap(x,y);
        xg(1,pos[top[x]],pos[x],z);
        x=fa[top[x]];
    }
    if(dep[x]<dep[y])
        swap(x,y);
    xg(1,pos[y],pos[x],z);
    return;
}
long long shuhe(long long x)
{
    return cx(1,pos[x],pos[x]+si[x]-1);
}
void shuchage(long long x,long long y)
{
    xg(1,pos[x],pos[x]+si[x]-1,y);
    return;
}
int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&root,&mo);
    for(long long i=1;i<=n;++i)
    {
        scanf("%d",&a[i]);
    }
    long long x,y;
    nw(1,n);//建节点
    for(long long i=1;i<n;++i)
    {
        scanf("%lld%lld",&x,&y);
        q[y].push_back(x);
        q[x].push_back(y);
    }
    dfs1(root);//预处理
    dfs2(root,root);
    for(long long i=1;i<=n;++i)
    {
        xg(1,pos[i],pos[i],a[i]);
    }
    long long k,h1,h2,h3;
    while(m)
    {
        m--;
        scanf("%lld",&k);
        if(k==1)
        {
            scanf("%lld%lld%lld",&h1,&h2,&h3);
            dianchage(h1,h2,h3);
        }
        if(k==2)
        {
            scanf("%lld%lld",&h1,&h2);
            printf("%lld\n",dianhe(h1,h2)%mo);
        }
        if(k==3)
        {
            scanf("%lld%lld",&h1,&h2);
            shuchage(h1,h2);
        }
        if(k==4)
        {
            scanf("%lld",&h1);
            printf("%lld\n",shuhe(h1)%mo);
        }
    }
    return 0;
}