树链剖分
枫林晚
2018-05-06 20:48:35
### 意义:
树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度
概念
重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多的那一个儿子 为该节点的重儿子
轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)
重边:连接任意两个重儿子的边叫做重边
轻边:剩下的即为轻边
重链:相邻重边连起来的 连接一条重儿子 的链叫重链
对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点
### 题目大意:
给定一棵有根树,给定每个点初值。
需要处理的问题:
将树从x到y结点最短路径上所有节点的值都加上z
求树从x到y结点最短路径上所有节点的值之和
将以x为根节点的子树内所有节点值都加上z
求以x为根节点的子树内所有节点值之和
分析:
树链剖分+线段树
树剖部分:
需要数组:
```cpp
int root,n,m,p;
int dfn[N],dfn2[N],fdfn[N];
int top[N],son[N],fa[N],dep[N],size[N];
```
1.dfs1:
目标:
①找到fa,重儿子(son)
②处理节点深度,子树大小(size)(dep[root]=1,fa[root]=-1,其实本题不固定)
```cpp
void dfs1(int x,int f,int d)
{
dep[x]=d;
size[x]=1;
int mx=0;
for(int i=head[x];i;i=bian[i].nxt)
{
int y=bian[i].to;
if(y==f) continue;//不能回走
fa[y]=x;
dfs1(y,x,d+1);
size[x]+=size[y];
if(size[y]>mx)
{
mx=size[y],son[x]=y;//记录重儿子
}
}
}
```
2.dfs2
目标:
①找到dfn,dfn2(子树结尾dfn)便于之后线段树维护区间。
②处理fdfn,记录dfnx是几号点。便于线段树build
③注意:有重儿子,先走重儿子。
结果:
dfn数组中,一棵完整的子树,其dfn也是连续的一段。每条重链也是连续的一段。这样,用线段树很方便维护树上路径的处理。
```cpp
void dfs2(int x,int f)
{
dfn[x]=++tot;
fdfn[tot]=x;//第tot个dfn是x号
if(!top[x]) top[x]=x;//top未赋值才能赋值
if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);//先走重儿子
for(int i=head[x];i;i=bian[i].nxt)
{
int y=bian[i].to;
if(y==son[x]||y==f) continue;
dfs2(y,x);
}
dfn2[x]=tot;//回溯之前记录下子树结尾dfn
}
```
此处省去线段树常规操作,详见下面代码。
3.work1
利用树剖lca想法,其中一个点一边向上翻的同时,更新值。最后在同一条链上了之后,相当于已经找到了lca直接更新另一条路径。
```cpp
void work1(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]]) swap(x,y);//dep[top]深度深的向上翻
add(1,1,tot,dfn[top[y]],dfn[y],z);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
add(1,1,tot,dfn[x],dfn[y],z);//另一边路径
}
```
work2同理。
4.work3,work4,利用之前记录过的dfn2,可以直接找到子树区间。直接处理即可。
```cpp
void work3(int x,int z)
{
add(1,1,tot,dfn[x],dfn2[x],z);
}
int work4(int x)
{
int sum=0;
sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p;
return sum;
}
```
### 注意事项:
1.每次dfs注意不要返祖。
2.记得取模!!!任何加减,赋值,求和都要提起注意。
3.区间add标记直接加,sum要+c×(len)必须乘区间!!(线段树不过关。。。)
4.root是原来树的根,线段树的根就是1!!(不要混了)RE无数无数无数
详见代码:
```cpp
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int a[N];
int root,n,m,p;
int dfn[N],dfn2[N],fdfn[N];
int top[N],son[N],fa[N],dep[N],size[N];
struct node{
int nxt,to;
}bian[2*N];
int cnt,tot;
int head[N];
void add(int x,int y)
{
bian[++cnt].nxt=head[x];
bian[cnt].to=y;
head[x]=cnt;
}
void dfs1(int x,int f,int d)
{
dep[x]=d;
size[x]=1;
int mx=0;
for(int i=head[x];i;i=bian[i].nxt)
{
int y=bian[i].to;
if(y==f) continue;
fa[y]=x;
dfs1(y,x,d+1);
size[x]+=size[y];
if(size[y]>mx)
{
mx=size[y],son[x]=y;
}
}
}
void dfs2(int x,int f)
{
dfn[x]=++tot;
fdfn[tot]=x;
if(!top[x]) top[x]=x;
if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);
for(int i=head[x];i;i=bian[i].nxt)
{
int y=bian[i].to;
if(y==son[x]||y==f) continue;
dfs2(y,x);
}
dfn2[x]=tot;
}
//-------------------以上树剖 -----------------------------------
int mod(int x)
{
while(x>=p) x-=p;
while(x<0) x+=p;
return x;
}
struct tree{
int sum,add;
#define s(x) t[x].sum
#define ad(x) t[x].add
}t[4*N];
void pushup(int x)
{
s(x)=mod(s(x<<1)+s(x<<1|1));
}
void build(int x,int l,int r)
{
if(l==r)
{
s(x)=mod(a[fdfn[l]]);ad(x)=0;
return;
}
int mid=(l+r)>>1;
build(x<<1,l,mid);
build(x<<1|1,mid+1,r);
pushup(x);
}
void pushdown(int x,int l,int r)//change sum+=ad*len
{
int s1=x<<1,s2=x<<1|1;
int mid=(l+r)>>1;
ad(s1)=mod(ad(s1)+ad(x));
s(s1)=mod(s(s1)+ad(x)*(mid-l+1));
ad(s2)=mod(ad(s2)+ad(x));
s(s2)=mod(s(s2)+ad(x)*(r-mid));
ad(x)=0;
}
void add(int x,int l,int r,int L,int R,int c)
{
if(L<=l&&r<=R)
{
s(x)=mod(s(x)+mod(c*(r-l+1)));
ad(x)=mod(ad(x)+c);
return;
}
pushdown(x,l,r);
int mid=(l+r)>>1;
if(L<=mid) add(x<<1,l,mid,L,R,c);
if(mid<R) add(x<<1|1,mid+1,r,L,R,c);
pushup(x);
}
int query(int x,int l,int r,int L,int R)
{
if(L<=l&&r<=R)
{
return s(x);
}
pushdown(x,l,r);
int mid=(l+r)>>1;
int res=0;
if(L<=mid) res=mod(res+query(x<<1,l,mid,L,R));
if(mid<R) res=mod(res+query(x<<1|1,mid+1,r,L,R));
return res;
}
//-------------------以上线段树 -----------------------------------
void work1(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]]) swap(x,y);
add(1,1,tot,dfn[top[y]],dfn[y],z);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
add(1,1,tot,dfn[x],dfn[y],z);
}
int work2(int x,int y)
{
int sum=0;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]]) swap(x,y);
sum=(sum+query(1,1,tot,dfn[top[y]],dfn[y]))%p;
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
sum=(sum+query(1,1,tot,dfn[x],dfn[y]))%p;
return sum;
}
void work3(int x,int z)
{
add(1,1,tot,dfn[x],dfn2[x],z);
}
int work4(int x)
{
int sum=0;
sum=(sum+query(1,1,tot,dfn[x],dfn2[x]))%p;
return sum;
}
int main()
{
scanf("%d%d%d%d",&n,&m,&root,&p);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
int x,y;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(root,-1,1);
dfs2(root,-1);
fa[root]=-1;
build(1,1,tot);
int op,z;
while(m)
{
scanf("%d",&op);
if(op==1)
{
scanf("%d%d%d",&x,&y,&z);
work1(x,y,z);
}
else if(op==2)
{
scanf("%d%d",&x,&y);
printf("%d\n",work2(x,y));
}
else if(op==3)
{
scanf("%d%d",&x,&z);
work3(x,z);
}
else{
scanf("%d",&x);
printf("%d\n",work4(x));
}
m--;
}
return 0;
}
```