树链剖分模板
Star_Cried · · 个人记录
看到有很多大佬写得都很好,我也是学习他们的,但是线段树用数组记录看着太难受了啊喂(╯‵□′)╯︵┻━┻
所以我就写一个结构体的线段树以备不同习惯同学之需,顺便写一下自己的理解。
我会告诉你我不小心把%写成了*调了一个多小时吗
适用范围
在一棵树上+-点权/边权然后多次提问的问题等
原理
将一棵树剖分成若干条链,在链上通过数据结构维护。
照本宣科
将一个点的子树中最大的那个做为重儿子,其他的叫做轻儿子,对于所有点的重儿子连的链称为重链。我们将重链作为主链,将轻儿子构成的其他轻链加在重链的后面。很明显,对于每一条链,它的节点都是相连的(废话)。对于每一个子树,它的所有的节点都排在子树根的后面。
所以我们就把一个树成功地退化成了一条链。 因为这条链有如上我所说的几个性质,题目要求如果是能在链上连续维护的(如求子树权值和,求任意两点间距离和),我们就可以用数据结构维护它了,比如线段树。
构造
首先我们需要两次dfs。
dfs1
用于建树,顺便记录每个节点的父亲 和 该点的深度 和 它的子树的大小 和 它的重儿子。
fa[]父亲节点 dep[]深度 siz[]子树大小 son[]重儿子(重儿子为子树大者)
void dfs1(int x,int f,int depth)
{
siz[x]=1; dep[x]=depth; fa[x]=f;
int maxson=-1;
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==f)continue;
dfs1(u,x,depth+1);
siz[x]+=siz[u];
if(siz[u]>maxson)son[x]=u,maxson=siz[u];
}
}
dfs2
用于将一个树退化成链,记录节点在链的编号 和 节点的链首节点 和 链上节点的权值。 cnt时间戳 id[]编号 top[]链首节点 a[]原值 w[]链上的值(便于维护)
void dfs2(int x,int topf)
{
id[x]=++cnt;
w[cnt]=a[x];
top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==fa[x]||u==son[x])continue;
dfs2(u,u);
}
}
求两点间距离
要求两点间距离:
若两点在一条链上(top[x]==top[y])我们直接求两点间的距离即可。
若两点不在一条链上,那么求更深的那个点x到此刻链首top[x]的距离,然后令x=fa[top[x]],即可将x更新到新链上。重复操作,每次只将深度更大的点向上更新。最终两点会处于同一条链(重链或轻链,最远是重链)上,然后再加上两点间的和就可以了。
用线段树维护。
inline int queryrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=0;
st.query(1,id[top[x]],id[x]);
ans=(ans+res)%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
res=0;
st.query(1,id[x],id[y]);
ans=(ans+res)%mod;
return ans;
}
更新两点间距离
和上面是一样的,分成若干条链更新就可以了。
inline void updaterange(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
st.update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
st.update(1,id[x],id[y],k);
}
更新/查询子树权值(和)
如上,子树节点在链上一定是在根的后面并且连续的。
所以要更新以x为根节点的所有子树结点,就更新id[x]~id[x]+size[x]-1的范围即可。
(代码放在代码里)
当然要根据题目需要做各种修改,这真的只是一个板子。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<ctype.h>
#define R register
using namespace std;
inline int read()
{
int x=0,w=0;char c=getchar();
while(!isdigit(c))w|=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
return w?-x:x;
}
const int maxn=1e5+5,maxm=1e5+5;
int n,m,root,mod,fa[maxn],dep[maxn],son[maxn],siz[maxn],top[maxn],id[maxn],a[maxn],w[maxn];
struct Edge{
int to,nxt;
}e[maxn*2];
int ecnt,head[maxn];
inline void addedge(int from,int to)
{
e[++ecnt]=(Edge){to,head[from]};head[from]=ecnt;
e[++ecnt]=(Edge){from,head[to]};head[to]=ecnt;
}
int res;
struct SegmentTree{
#define ls (ro<<1)
#define rs (ro<<1|1)
struct tree{
int l,r,val,tag;
}e[maxn<<2];
inline void push_up(int ro)
{
e[ro].val=(e[ls].val+e[rs].val)%mod;
}
inline void push_down(int ro)
{
e[ls].tag+=e[ro].tag;
e[rs].tag+=e[ro].tag;
e[ls].val+=e[ro].tag*(e[ls].r-e[ls].l+1);
e[rs].val+=e[ro].tag*(e[rs].r-e[rs].l+1);
e[ls].val%=mod;
e[rs].val%=mod;
e[ro].tag=0;
}
void build(int ro,int l,int r)
{
e[ro].l=l;e[ro].r=r;
if(l==r){e[ro].val=w[l]%mod;return ;}
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
push_up(ro);
}
void update(int ro,int l,int r,int k)
{
if(e[ro].l>=l and e[ro].r<=r){
e[ro].tag+=k;
e[ro].val+=k*(e[ro].r-e[ro].l+1);
e[ro].val%=mod;
return;
}
int mid=(e[ro].l+e[ro].r)>>1;
if(e[ro].tag)push_down(ro);
if(l<=mid)update(ls,l,r,k);
if(r>mid)update(rs,l,r,k);
push_up(ro);
}
void query(int ro,int l,int r)
{
if(e[ro].l>=l and e[ro].r<=r){
res+=e[ro].val;res%=mod;return;
}
if(e[ro].tag)push_down(ro);
int mid=(e[ro].l+e[ro].r)>>1;
if(l<=mid)query(ls,l,r);
if(r>mid)query(rs,l,r);
}
#undef ls
#undef rs
}st;
void dfs1(int x,int f,int depth)
{
siz[x]=1; dep[x]=depth; fa[x]=f;
int maxson=-1;
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==f)continue;
dfs1(u,x,depth+1);
siz[x]+=siz[u];
if(siz[u]>maxson)son[x]=u,maxson=siz[u];
}
}
int cnt;
void dfs2(int x,int topf)
{
id[x]=++cnt;
w[cnt]=a[x];
top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==fa[x]||u==son[x])continue;
dfs2(u,u);
}
}
inline void updaterange(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
st.update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
st.update(1,id[x],id[y],k);
}
inline int queryrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=0;
st.query(1,id[top[x]],id[x]);
ans=(ans+res)%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
res=0;
st.query(1,id[x],id[y]);
ans=(ans+res)%mod;
return ans;
}
inline void debug()
{
for(int i=1;i<=n;i++)
{
res=0;st.query(1,id[i],id[i]);
printf("%d ",res);
}
}
int main()
{
n=read(),m=read(),root=read(),mod=read();
for(R int i=1;i<=n;i++)a[i]=read();
for(R int i=1;i<n;i++)addedge(read(),read());
dfs1(root,0,1);
dfs2(root,root);
st.build(1,1,n);
while(m--)
{
int a,b,c;
switch(read())
{
case(1):{
a=read(),b=read(),c=read();
updaterange(a,b,c);
break;
}
case(2):{
a=read(),b=read();
printf("%d\n",queryrange(a,b));
break;
}
case(3):{
int a=read(),b=read();
st.update(1,id[a],id[a]+siz[a]-1,b);
break;
}
case(4):{
int a=read();res=0;
st.query(1,id[a],id[a]+siz[a]-1);
printf("%d\n",res);
break;
}
}
}
//debug();
return 0;
}