树链剖分

万弘

2019-02-26 17:23:46

Personal

## 树链剖分 目录: 1. 树与DFS 2. 轻重边 3. 树链剖分 4. 代码实现(C++) 前置技能:树,DFS,[线段树](https://www.luogu.org/blog/73142/ji-chu-xian-duan-shu)等数据结构 为方便,以[此题](https://www.luogu.org/problemnew/show/P3384)为例 题意:给一棵有点权的树,要支持四种操作 > 操作1: 格式: 1 x y z :将树从x到y结点最短路径上所有节点的值都加上z 操作2: 格式: 2 x y :求树从x到y结点最短路径上所有节点的值之和 操作3: 格式: 3 x z :将以x为根节点的子树内所有节点值都加上z 操作4: 格式: 4 x :求以x为根节点的子树内所有节点值之和 ### 1.树与DFS 比如有这样一棵树: ![1](https://cdn.luogu.com.cn/upload/pic/53329.png) (圈内的数字为先序遍历的时间戳) `设dfn[i]表示i的时间戳` 则对于u,其子树中的所有v都有`dfn[u]<=dfn[v]`,令u维护 $$[dfn[u],\max_v dfn[v]]$$ 就维护了整棵子树 就可以把它丢到线段树等数据结构上以单次操作$O(logn)$完成操作3&4了 ### 2.轻重边 `设size[u]表示以u为根(包含u)的子树大小` > 重儿子:u的儿子中使size[v]最大的v 重边:连接u与重儿子的边 轻边:连接u与其非重儿子的边 重链:连续的重边组成的链 ![](https://cdn.luogu.com.cn/upload/pic/53330.png) (图中粗的为重边,轻的为轻边) ### 3.树链剖分 如上图,如果我们要查询3到7的链的和,就得暴力,单次最坏$O(n)$ 这时,只要我们预处理出每个u的链顶top[u],就有: - u到top[u]是dfs序上连续的一段,可以用线段树直接$O(logn)$地加上这一段 - **对于任意u->v的链,经过的重链与轻边数的和不超过$O(logn)$** 所以在查询u->v的和时 - 若u,v在同一条重链上,直接`add(dfn[u],dfn[v])`(不过要保证dfn[u]<dfn[v]) - 否则,**让dep[top[]]更深为i**,进行`add(dfn[top[i]],dfn[i]),i=fa[top[i]]`,让i跳一条重链或者轻边 **注意,是dep[top[]]深的先跳,不是dep[]深的先跳** 比如 ![](https://cdn.luogu.com.cn/upload/pic/53330.png) 解释:如果查询4到6的和,让dep深的先跳,4跳到1,6跳到2,再进行add(1,2) 最后4,3,2,1,6,2,1被查询 但显然应该是4,3,2,6被查询 所以树链剖分对1&2操作单次复杂度$O(log^2n)$ 树链剖分的预处理分为两次dfs: 1. 第一次dfs记录父亲,size[],重儿子,深度 2. 从根节点开始沿着重儿子往下拉重链,不在当前重链上的点,再以其为起点重新拉重链,同时记录top[],dfn[] ### 4.代码实现(add,query都应该访问dfn而不是节点编号,tcp::t[]表示dfn[]) ```cpp #include<iostream> #include<cstdio> typedef long long ll; typedef unsigned un; #define INF (1ll<<50) ll read() { char c; ll f=1,x=0; do { c=getchar(); if(c=='-')f=-1; }while(c<'0'||c>'9'); do { x=x*10+c-'0'; c=getchar(); }while(c>='0'&&c<='9'); return f*x; } void write(ll x) { if(x==0) { putchar('0'); putchar(' '); return; } if(x<0) { putchar('-'); x=-x; } ll s[21],top=0; while(x) { s[++top]=x%10; x/=10; } for(ll i=top;i>=1;--i)putchar(s[i]+'0'); putchar(' '); } #define maxn 100010 ll n,m,r,p; struct Edge { ll v,nxt; }e[maxn<<2|1]; ll last[maxn],cnt=0; void adde(ll u,ll v) { e[++cnt].v=v; e[cnt].nxt=last[u];last[u]=cnt; } struct Segment_Tree//线段树 { struct node { ll v,tag; }t[maxn<<2|1]; #define lson num<<1 #define rson num<<1|1 #define rt t[num] #define tl t[num<<1] #define tr t[num<<1|1] void pushup(un num) { rt.v=(tl.v+tr.v)%p; } void pushdown(un l,un r,un num) { if(!rt.tag)return; un mid=(l+r)>>1; tl.v=(tl.v+(mid-l+1)*rt.tag)%p; tl.tag=(tl.tag+rt.tag)%p; tr.v=(tr.v+(r-mid)*rt.tag)%p; tr.tag=(tr.tag+rt.tag)%p; rt.tag=0; } void build(ll *a,un l=1,un r=n,un num=1) { if(l==r)rt.v=a[l]; else { un mid=(l+r)>>1; build(a,l,mid,lson);build(a,mid+1,r,rson); pushup(num); } rt.tag=0; } void add(un ql,un qr,ll k,un l=1,un r=n,un num=1)//[ql,qr]+=k; { if(ql<=l&&r<=qr) { rt.v=(rt.v+(r-l+1)*k)%p; rt.tag=(rt.tag+k)%p; return; } if(l>qr||r<ql)return; un mid=(l+r)>>1; pushdown(l,r,num); add(ql,qr,k,l,mid,lson);add(ql,qr,k,mid+1,r,rson); pushup(num); } ll sum(un ql,un qr,un l=1,un r=n,un num=1)//\sum [ql,qr] { if(ql<=l&&r<=qr)return rt.v; if(l>qr||r<ql)return 0; ll mid=(l+r)>>1; pushdown(l,r,num); return (sum(ql,qr,l,mid,lson)+sum(ql,qr,mid+1,r,rson))%p; } }sgt; namespace tcp//树链剖分 { ll fa[maxn],son[maxn],dep[maxn],size[maxn]; ll top[maxn],a[maxn],t[maxn],tot=0; void dfs1(ll u) { size[u]=1; for(ll i=last[u];i;i=e[i].nxt) { ll v=e[i].v; if(v==fa[u])continue; fa[v]=u;dep[v]=dep[u]+1; dfs1(v); size[u]+=size[v]; if(size[v]>size[son[u]])son[u]=v; } } void dfs2(ll *w,ll u,ll utop) { t[u]=++tot;top[u]=utop;a[tot]=w[u]; if(son[u])dfs2(w,son[u],utop); for(ll i=last[u];i;i=e[i].nxt) { ll v=e[i].v; if(v==fa[u]||v==son[u])continue; dfs2(w,v,v); } } void build(ll *w,ll r) { dfs1(r); dfs2(w,r,r); sgt.build(a); } void add(ll u,ll v,ll k) { while(top[u]!=top[v]) { if(dep[top[u]]>dep[top[v]]) { sgt.add(t[top[u]],t[u],k); u=fa[top[u]]; } else { sgt.add(t[top[v]],t[v],k); v=fa[top[v]]; } } if(dep[u]<dep[v])sgt.add(t[u],t[v],k); else sgt.add(t[v],t[u],k); } ll sum(ll u,ll v) { ll res=0; while(top[u]!=top[v]) { if(dep[top[u]]>dep[top[v]]) { res=(res+sgt.sum(t[top[u]],t[u]))%p; u=fa[top[u]]; } else { res=(res+sgt.sum(t[top[v]],t[v]))%p; v=fa[top[v]]; } } if(dep[u]<dep[v])return (res+sgt.sum(t[u],t[v]))%p; else return (res+sgt.sum(t[v],t[u]))%p; } void adt(ll u,ll k) { sgt.add(t[u],t[u]+size[u]-1,k); } ll qt(ll u) { return sgt.sum(t[u],t[u]+size[u]-1); } } ll w[maxn]; int main() { n=read(),m=read(),r=read(),p=read(); for(ll i=1;i<=n;++i)w[i]=read(); for(ll i=1;i<n;++i) { ll u=read(),v=read(); adde(u,v); adde(v,u); } tcp::build(w,r); for(ll i=1;i<=m;++i) { ll op=read(),x,y,z; if(op==1) { x=read(),y=read(),z=read(); tcp::add(x,y,z); } else if(op==2) { x=read(),y=read(); printf("%lld\n",tcp::sum(x,y)); } else if(op==3) { x=read(),z=read(); tcp::adt(x,z); } else { x=read(); printf("%lld\n",tcp::qt(x)); } } return 0; } ``` 没什么注释,有疑问私信我