树链剖分
万弘
2019-02-26 17:23:46
## 树链剖分
目录:
1. 树与DFS
2. 轻重边
3. 树链剖分
4. 代码实现(C++)
前置技能:树,DFS,[线段树](https://www.luogu.org/blog/73142/ji-chu-xian-duan-shu)等数据结构
为方便,以[此题](https://www.luogu.org/problemnew/show/P3384)为例
题意:给一棵有点权的树,要支持四种操作
> 操作1: 格式: 1 x y z :将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y :求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z :将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x :求以x为根节点的子树内所有节点值之和
### 1.树与DFS
比如有这样一棵树:
![1](https://cdn.luogu.com.cn/upload/pic/53329.png)
(圈内的数字为先序遍历的时间戳)
`设dfn[i]表示i的时间戳`
则对于u,其子树中的所有v都有`dfn[u]<=dfn[v]`,令u维护
$$[dfn[u],\max_v dfn[v]]$$
就维护了整棵子树
就可以把它丢到线段树等数据结构上以单次操作$O(logn)$完成操作3&4了
### 2.轻重边
`设size[u]表示以u为根(包含u)的子树大小`
> 重儿子:u的儿子中使size[v]最大的v
重边:连接u与重儿子的边
轻边:连接u与其非重儿子的边
重链:连续的重边组成的链
![](https://cdn.luogu.com.cn/upload/pic/53330.png)
(图中粗的为重边,轻的为轻边)
### 3.树链剖分
如上图,如果我们要查询3到7的链的和,就得暴力,单次最坏$O(n)$
这时,只要我们预处理出每个u的链顶top[u],就有:
- u到top[u]是dfs序上连续的一段,可以用线段树直接$O(logn)$地加上这一段
- **对于任意u->v的链,经过的重链与轻边数的和不超过$O(logn)$**
所以在查询u->v的和时
- 若u,v在同一条重链上,直接`add(dfn[u],dfn[v])`(不过要保证dfn[u]<dfn[v])
- 否则,**让dep[top[]]更深为i**,进行`add(dfn[top[i]],dfn[i]),i=fa[top[i]]`,让i跳一条重链或者轻边
**注意,是dep[top[]]深的先跳,不是dep[]深的先跳**
比如 ![](https://cdn.luogu.com.cn/upload/pic/53330.png)
解释:如果查询4到6的和,让dep深的先跳,4跳到1,6跳到2,再进行add(1,2)
最后4,3,2,1,6,2,1被查询
但显然应该是4,3,2,6被查询
所以树链剖分对1&2操作单次复杂度$O(log^2n)$
树链剖分的预处理分为两次dfs:
1. 第一次dfs记录父亲,size[],重儿子,深度
2. 从根节点开始沿着重儿子往下拉重链,不在当前重链上的点,再以其为起点重新拉重链,同时记录top[],dfn[]
### 4.代码实现(add,query都应该访问dfn而不是节点编号,tcp::t[]表示dfn[])
```cpp
#include<iostream>
#include<cstdio>
typedef long long ll;
typedef unsigned un;
#define INF (1ll<<50)
ll read()
{
char c;
ll f=1,x=0;
do
{
c=getchar();
if(c=='-')f=-1;
}while(c<'0'||c>'9');
do
{
x=x*10+c-'0';
c=getchar();
}while(c>='0'&&c<='9');
return f*x;
}
void write(ll x)
{
if(x==0)
{
putchar('0');
putchar(' ');
return;
}
if(x<0)
{
putchar('-');
x=-x;
}
ll s[21],top=0;
while(x)
{
s[++top]=x%10;
x/=10;
}
for(ll i=top;i>=1;--i)putchar(s[i]+'0');
putchar(' ');
}
#define maxn 100010
ll n,m,r,p;
struct Edge
{
ll v,nxt;
}e[maxn<<2|1];
ll last[maxn],cnt=0;
void adde(ll u,ll v)
{
e[++cnt].v=v;
e[cnt].nxt=last[u];last[u]=cnt;
}
struct Segment_Tree//线段树
{
struct node
{
ll v,tag;
}t[maxn<<2|1];
#define lson num<<1
#define rson num<<1|1
#define rt t[num]
#define tl t[num<<1]
#define tr t[num<<1|1]
void pushup(un num)
{
rt.v=(tl.v+tr.v)%p;
}
void pushdown(un l,un r,un num)
{
if(!rt.tag)return;
un mid=(l+r)>>1;
tl.v=(tl.v+(mid-l+1)*rt.tag)%p;
tl.tag=(tl.tag+rt.tag)%p;
tr.v=(tr.v+(r-mid)*rt.tag)%p;
tr.tag=(tr.tag+rt.tag)%p;
rt.tag=0;
}
void build(ll *a,un l=1,un r=n,un num=1)
{
if(l==r)rt.v=a[l];
else
{
un mid=(l+r)>>1;
build(a,l,mid,lson);build(a,mid+1,r,rson);
pushup(num);
}
rt.tag=0;
}
void add(un ql,un qr,ll k,un l=1,un r=n,un num=1)//[ql,qr]+=k;
{
if(ql<=l&&r<=qr)
{
rt.v=(rt.v+(r-l+1)*k)%p;
rt.tag=(rt.tag+k)%p;
return;
}
if(l>qr||r<ql)return;
un mid=(l+r)>>1;
pushdown(l,r,num);
add(ql,qr,k,l,mid,lson);add(ql,qr,k,mid+1,r,rson);
pushup(num);
}
ll sum(un ql,un qr,un l=1,un r=n,un num=1)//\sum [ql,qr]
{
if(ql<=l&&r<=qr)return rt.v;
if(l>qr||r<ql)return 0;
ll mid=(l+r)>>1;
pushdown(l,r,num);
return (sum(ql,qr,l,mid,lson)+sum(ql,qr,mid+1,r,rson))%p;
}
}sgt;
namespace tcp//树链剖分
{
ll fa[maxn],son[maxn],dep[maxn],size[maxn];
ll top[maxn],a[maxn],t[maxn],tot=0;
void dfs1(ll u)
{
size[u]=1;
for(ll i=last[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(v==fa[u])continue;
fa[v]=u;dep[v]=dep[u]+1;
dfs1(v);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs2(ll *w,ll u,ll utop)
{
t[u]=++tot;top[u]=utop;a[tot]=w[u];
if(son[u])dfs2(w,son[u],utop);
for(ll i=last[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(v==fa[u]||v==son[u])continue;
dfs2(w,v,v);
}
}
void build(ll *w,ll r)
{
dfs1(r);
dfs2(w,r,r);
sgt.build(a);
}
void add(ll u,ll v,ll k)
{
while(top[u]!=top[v])
{
if(dep[top[u]]>dep[top[v]])
{
sgt.add(t[top[u]],t[u],k);
u=fa[top[u]];
}
else
{
sgt.add(t[top[v]],t[v],k);
v=fa[top[v]];
}
}
if(dep[u]<dep[v])sgt.add(t[u],t[v],k);
else sgt.add(t[v],t[u],k);
}
ll sum(ll u,ll v)
{
ll res=0;
while(top[u]!=top[v])
{
if(dep[top[u]]>dep[top[v]])
{
res=(res+sgt.sum(t[top[u]],t[u]))%p;
u=fa[top[u]];
}
else
{
res=(res+sgt.sum(t[top[v]],t[v]))%p;
v=fa[top[v]];
}
}
if(dep[u]<dep[v])return (res+sgt.sum(t[u],t[v]))%p;
else return (res+sgt.sum(t[v],t[u]))%p;
}
void adt(ll u,ll k)
{
sgt.add(t[u],t[u]+size[u]-1,k);
}
ll qt(ll u)
{
return sgt.sum(t[u],t[u]+size[u]-1);
}
}
ll w[maxn];
int main()
{
n=read(),m=read(),r=read(),p=read();
for(ll i=1;i<=n;++i)w[i]=read();
for(ll i=1;i<n;++i)
{
ll u=read(),v=read();
adde(u,v);
adde(v,u);
}
tcp::build(w,r);
for(ll i=1;i<=m;++i)
{
ll op=read(),x,y,z;
if(op==1)
{
x=read(),y=read(),z=read();
tcp::add(x,y,z);
}
else if(op==2)
{
x=read(),y=read();
printf("%lld\n",tcp::sum(x,y));
}
else if(op==3)
{
x=read(),z=read();
tcp::adt(x,z);
}
else
{
x=read();
printf("%lld\n",tcp::qt(x));
}
}
return 0;
}
```
没什么注释,有疑问私信我