树剖
刚刚学树剖的时候,因为有一些图论的知识,所以理论知识很快就明白了,但代码码了很久,蒟蒻的代码能力还是太垃圾了。
废话就不多说,先在我就谈谈树剖;
先上一道题吧(洛谷P3384)(大佬勿喷) 看到这道题大家有什么想法呢?暴力?(n<=10^5 )(一看就知道用 logn 的算法)a不了吧,废话不多说,正式进入正题。 什么是树剖,树剖就是它的字面意思,将树剖成很多部分然后用线段树维护。
在我们来谈如何进行具体操作之前,我们要来学习一下一些与树剖有关的基础树的知识。
重儿子:num[u]为v的子节点中num值最大的,那么u就是v的重儿子。 轻儿子:v的其它子节点。 重边:点v与其重儿子的连边。 轻边:点v与其轻儿子的连边。 重链:由重边连成的路径。
轻链:轻边。
(本图摘自洛谷)
图中黑色的边是重边,红色的点是重儿子;
剖分后的树有如下性质: 性质1:如果(v,u)为轻边,则siz[u] * 2 < siz[v];
性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。
大概了解了些基本知识,我们就来谈谈那道题到底怎么求解吧;
首先,我们要进行两次dfs,第一次dfs将他的fa、deep、size、son//求的便是英文字母的字面意思
话不多说,上代码;
void dfs1(long long x)//当前所在节点,开始及是根
{
si[x]=1;//将节点的子树大小初始化为1,及自己
for(long long i=0;i<q[x].size();++i)//遍历自己的儿子
{
if(q[x][i]==fa[x])//不为爹
continue;
dep[q[x][i]]=dep[x]+1;//儿子深度为爸爸深度加1
fa[q[x][i]]=x;//把儿子的爸爸找到
dfs1(q[x][i]);//继续dfs儿子
si[x]+=si[q[x][i]];//递归回来x的树大小便是儿子的树大小相加再加上自己的大小(
自己的大小便是1啊)}}
dfs1完了,就进行dfs2,dfs2有什么用呢?
即将树按照重链重新编个号(我代码定义的pos),还要与此同时找到每一个点的链顶(我代码定义的top)
代码如下:
void dfs2(long long x,long long grfa)//x为当前节点,grfa及当前节点的链顶
{
long long k=0;//判断重链的计数器
sz++;//当前加到的节号
pos[x]=sz;
top[x]=grfa;
for(long long i=0;i<q[x].size();++i)//遍历儿子
{
if(dep[x]<dep[q[x][i]] and si[q[x][i]]>si[k])//找重儿子
{
k=q[x][i];
}
}
if(k==0)//没有儿子
return;
dfs2(k,grfa);//先递归把重链找到
for(long long i=0;i<q[x].size();++i)//遍历不是重儿子的其它儿子
{
if(dep[x]<dep[q[x][i]] and k!=q[x][i])
dfs2(q[x][i],q[x][i]);
}
}
我们将预处理搞好了,接下来便是用线段树来维这一棵树(以下是线段树的模板,我就不详细讲了):
struct node{long long l,r,ls,rs,c,f;}tr[maxn];
//l,r为区间,ls左节点,rs右节点,c是data,f是lazy标记
long long nw(long long l,long long r)//蒟蒻的laji线段树//新开区间为l到r的节点
{
tr[++t]=(node){l,r,0,0,0,0};
return t;//返回节点号
}
void xf(long long o)//下放lazy标记
{
long long l=tr[o].ls,r=tr[o].rs;
tr[o].c+=(tr[o].r-tr[o].l+1)*tr[o].f%mo;
tr[l].f+=tr[o].f;
tr[r].f+=tr[o].f;
tr[o].f=0;
}
void gx(long long o){//将节点data递归会来
long long l=tr[o].ls,r=tr[o].rs;
xf(l),xf(r);
tr[o].c=(tr[l].c+tr[r].c)%mo;
}
void xg(long long o,long long l,long long r,long long c)//将l到r区间加上c
{
xf(o);
if(tr[o].l==l&&tr[o].r==r)
{
tr[o].f+=c;
return;
}
long long mid=tr[o].l+tr[o].r>>1;
if(r<=mid)
xg(tr[o].ls?tr[o].ls:tr[o].ls=nw(tr[o].l,mid),l,r,c);
else
if(l>mid)
xg(tr[o].rs?tr[o].rs:tr[o].rs=nw(mid+1,tr[o].r),l,r,c);
else
xg(tr[o].ls,l,mid,c),xg(tr[o].rs,mid+1,r,c);
gx(o);
}
long long cx(long long o,long long l,long long r)//求l到r的值
{
xf(o);
if(tr[o].l==l&&tr[o].r==r)
return tr[o].c%mo;
long long mid=tr[o].l+tr[o].r>>1;
if(r<=mid)
return cx(tr[o].ls,l,r);
else
if(l>mid)
return cx(tr[o].rs,l,r);
else
return (cx(tr[o].ls,l,mid)+cx(tr[o].rs,mid+1,r))%mo;
}
下面代码便是如何将树和线段树结合其来:
long long dianhe(long long x,long long y)//求两点间的距离
{
long long ans=0;
while(top[x]!=top[y])//如果两点不在一条链上
{
if(dep[top[x]]<dep[top[y]])//保证x的链顶比y的链顶深
swap(x,y);
ans+=cx(1,pos[top[x]],pos[x]);//直接将x到x的链顶的距离用线段树求和
x=fa[top[x]];//将x跳到链顶的爸爸
}
if(dep[x]<dep[y])//当到一条链上,保证x的深度>y
swap(x,y);
ans+=cx(1,pos[y],pos[x]);//直接用线段树查询x到y的距离
return ans%mo;
}
void dianchage(long long x,long long y,long long z)//将x和y两点之间加上z
{
while(top[x]!=top[y])//就不多说了,和上面一样,将cx改为修改及可
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
xg(1,pos[top[x]],pos[x],z);
x=fa[top[x]];
}
if(dep[x]<dep[y])
swap(x,y);
xg(1,pos[y],pos[x],z);
return;
}
long long shuhe(long long x)//求树的和
{
return cx(1,pos[x],pos[x]+si[x]-1);//因为在树上树儿子编号是连续的,直接区间求和及可
}
void shuchage(long long x,long long y)//将子树每个点加上y
{
xg(1,pos[x],pos[x]+si[x]-1,y);//同上
return;
}
如果将上面的看明白了,树剖就差不多入门了(蒟蒻也不可能给你们讲什么高级的东西)
下面我就直接上完整代码吧
#include<bits/stdc++.h>
using namespace std;
const long long maxn=400100;
struct node
{
long long l,r,ls,rs,c,f;
}tr[maxn];
long long n,m,root,mo,sz;
long long a[maxn],si[maxn],dep[maxn],fa[maxn],pos[maxn],top[maxn];
vector<long long>q[maxn];
long long t;
long long nw(long long l,long long r)
{
tr[++t]=(node){l,r,0,0,0,0};
return t;
}
void xf(long long o)
{
long long l=tr[o].ls,r=tr[o].rs;
tr[o].c+=(tr[o].r-tr[o].l+1)*tr[o].f%mo;
tr[l].f+=tr[o].f;
tr[r].f+=tr[o].f;
tr[o].f=0;
}
void gx(long long o){
long long l=tr[o].ls,r=tr[o].rs;
xf(l),xf(r);
tr[o].c=(tr[l].c+tr[r].c)%mo;
}
void xg(long long o,long long l,long long r,long long c)
{
xf(o);
if(tr[o].l==l&&tr[o].r==r)
{
tr[o].f+=c;
return;
}
long long mid=tr[o].l+tr[o].r>>1;
if(r<=mid)
xg(tr[o].ls?tr[o].ls:tr[o].ls=nw(tr[o].l,mid),l,r,c);
else
if(l>mid)
xg(tr[o].rs?tr[o].rs:tr[o].rs=nw(mid+1,tr[o].r),l,r,c);
else
xg(tr[o].ls,l,mid,c),xg(tr[o].rs,mid+1,r,c);
gx(o);
}
long long cx(long long o,long long l,long long r)
{
xf(o);
if(tr[o].l==l&&tr[o].r==r)
return tr[o].c%mo;
long long mid=tr[o].l+tr[o].r>>1;
if(r<=mid)
return cx(tr[o].ls,l,r);
else
if(l>mid)
return cx(tr[o].rs,l,r);
else
return (cx(tr[o].ls,l,mid)+cx(tr[o].rs,mid+1,r))%mo;
}
void dfs1(long long x)
{
si[x]=1;
for(long long i=0;i<q[x].size();++i)
{
if(q[x][i]==fa[x])
continue;
dep[q[x][i]]=dep[x]+1;
fa[q[x][i]]=x;
dfs1(q[x][i]);
si[x]+=si[q[x][i]];
}
}
void dfs2(long long x,long long grfa)
{
long long k=0;
sz++;
pos[x]=sz;
top[x]=grfa;
for(long long i=0;i<q[x].size();++i)
{
if(dep[x]<dep[q[x][i]] and si[q[x][i]]>si[k])
{
k=q[x][i];
}
}
if(k==0)
return;
dfs2(k,grfa);
for(long long i=0;i<q[x].size();++i)
{
if(dep[x]<dep[q[x][i]] and k!=q[x][i])
dfs2(q[x][i],q[x][i]);
}
}
long long dianhe(long long x,long long y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ans+=cx(1,pos[top[x]],pos[x]);
x=fa[top[x]];
}
if(dep[x]<dep[y])
swap(x,y);
ans+=cx(1,pos[y],pos[x]);
return ans%mo;
}
void dianchage(long long x,long long y,long long z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
xg(1,pos[top[x]],pos[x],z);
x=fa[top[x]];
}
if(dep[x]<dep[y])
swap(x,y);
xg(1,pos[y],pos[x],z);
return;
}
long long shuhe(long long x)
{
return cx(1,pos[x],pos[x]+si[x]-1);
}
void shuchage(long long x,long long y)
{
xg(1,pos[x],pos[x]+si[x]-1,y);
return;
}
int main()
{
scanf("%lld%lld%lld%lld",&n,&m,&root,&mo);
for(long long i=1;i<=n;++i)
{
scanf("%d",&a[i]);
}
long long x,y;
nw(1,n);//建节点
for(long long i=1;i<n;++i)
{
scanf("%lld%lld",&x,&y);
q[y].push_back(x);
q[x].push_back(y);
}
dfs1(root);//预处理
dfs2(root,root);
for(long long i=1;i<=n;++i)
{
xg(1,pos[i],pos[i],a[i]);
}
long long k,h1,h2,h3;
while(m)
{
m--;
scanf("%lld",&k);
if(k==1)
{
scanf("%lld%lld%lld",&h1,&h2,&h3);
dianchage(h1,h2,h3);
}
if(k==2)
{
scanf("%lld%lld",&h1,&h2);
printf("%lld\n",dianhe(h1,h2)%mo);
}
if(k==3)
{
scanf("%lld%lld",&h1,&h2);
shuchage(h1,h2);
}
if(k==4)
{
scanf("%lld",&h1);
printf("%lld\n",shuhe(h1)%mo);
}
}
return 0;
}