浅谈树链剖分
r1sing
·
·
算法·理论
处理的问题
在一棵树上,让你实现一些路径上的修改查询问题,那么很有可能树链剖分可以解决。
树链剖分的思想
因为一些问题,如果放在一条链上可以使用一些数据结构来维护(树状数组、线段树等),树链剖分就是把一颗树通过某些方法分成若干条链,来优化处理问题的时间复杂度。
树链剖分的定义
对于一个非叶节点,我们把它的子树大小最大的子节点称为重子节点,其它子节点称为轻子节点。连接重子节点的边叫重边,连接轻子节点的边叫轻边。又若干条首尾相接重边所组成的链叫重链,特别的,对于是叶子的轻子节点,其自己也组成了一条重链。根据定义可以发现任意一棵树都可被剖分成若干条重链。
树链剖分所需信息
对于树链剖分,我们需要知道任意节点的以下信息,来进行剖分:\
$top(x)$ 节点 $x$ 所在的重链的顶点\
$dep(x)$ 节点 $x$ 的深度\
$son(x)$ 节点 $x$ 的重子节点\
$siz(x)$ 节点 $x$ 的子树大小\
$dfn(x)$ 节点 $x$ 的 dfn 编号\
$rnk(x)$ dfn 编号为 x 的节点
## 如何去求得这些信息
首先可以发现,这些信息都可以通过 dfs 来求得。可是如果每个信息都做一次 dfs 的话,太麻烦了。但是这里一些信息有依赖的关系(要先知道 $son(x)$ 才能求出 $dfn(x)$ 等)\
我们整理一下信息可以发现,两次 dfs 即可求得所有信息。\
第一次 dfs 求 $fa(x), dep(x), son(x), siz(x)$。\
第二次 dfs 求 $top(x), dfn(x), rnk(x)$。
## 树链剖分性质
树上每个节点属于且仅属于一条重链
重链上的点的 dfn 编号是连续的
对于树上任意一条路径,经过的重链条数是 $\log(n)$ 级别的
## 模板
[模板题传送门](https://www.luogu.com.cn/problem/P3384)
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
ll n, m, r, p;
ll top[100005], son[100005], fa[100005], dfn[100005], rk[100005], dep[100005], siz[100005];
vector <ll> adj[100005];
ll w[100005];
ll dt;
struct segtree
{
ll lson, rson;
ll sum, add;
}tree[200005];
ll cnt, root;
void change(ll &k, ll v, ll l, ll r)
{
if(!k)
k = ++cnt;
tree[k].sum += v * (r - l + 1) % p;
tree[k].sum %= p;
tree[k].add += v;
tree[k].add %= p;
return ;
}
void pushdown(ll k, ll l, ll r)
{
ll mid = l + r >> 1;
change(tree[k].lson, tree[k].add, l, mid);
change(tree[k].rson, tree[k].add, mid + 1, r);
tree[k].add = 0;
}
void update(ll x, ll y, ll z, ll l, ll r, ll &k)
{
if(!k)
k = ++cnt;
if(x <= l && r <= y)
{
change(k, z, l, r);
return ;
}
pushdown(k, l, r);
ll mid = l + r >> 1;
if(x <= mid)
update(x, y, z, l, mid, tree[k].lson);
if(y > mid)
update(x, y, z, mid + 1, r, tree[k].rson);
tree[k].sum = tree[tree[k].lson].sum + tree[tree[k].rson].sum;
return ;
}
ll query(ll x, ll y, ll l, ll r, ll &k)
{
if(!k)
return 0;
if(x <= l && r <= y)
return tree[k].sum % p;
pushdown(k, l, r);
ll mid = l + r >> 1, sum = 0;
if(x <= mid)
sum += query(x, y, l, mid, tree[k].lson), sum %= p;
if(y > mid)
sum += query(x, y, mid + 1, r, tree[k].rson), sum %= p;
return sum % p;
}
void dfs1(ll u, ll f) // calc fa dep siz son
{
siz[u] = 1;
fa[u] = f;
ll mss = 0;
for(ll it:adj[u])
{
if(siz[it])
continue;
dep[it] = dep[u] + 1;
dfs1(it, u);
if(siz[it] > siz[mss])
mss = it;
siz[u] += siz[it];
}
son[u] = mss;
return ;
}
void dfs2(ll u, ll tp) // calc dfn rk top
{
dfn[u] = ++dt;
rk[dfn[u]] = u;
top[u] = tp;
if(son[u])
dfs2(son[u], tp);
for(ll it:adj[u])
{
if(it == fa[u] || it == son[u])
continue;
dfs2(it, it);
}
return ;
}
ll qry(ll u, ll v)
{
ll fu = top[u], fv = top[v], ans = 0;
while(fu != fv)
{
if(dep[fu] >= dep[fv])
{
ans += query(dfn[fu], dfn[u], 1, n, root);
ans %= p;
u = fa[fu], fu = top[u];
}
else
{
ans += query(dfn[fv], dfn[v], 1, n, root);
ans %= p;
v = fa[fv], fv = top[v];
}
}
if(dep[u] >= dep[v])
ans += query(dfn[v], dfn[u], 1, n, root);
else
ans += query(dfn[u], dfn[v], 1, n, root);
ans %= p;
return ans;
}
void upd(ll u, ll v, ll x)
{
ll fu = top[u], fv = top[v], ans = 0;
while(fu != fv)
{
if(dep[fu] >= dep[fv])
{
update(dfn[fu], dfn[u], x, 1, n, root);
u = fa[fu], fu = top[u];
}
else
{
update(dfn[fv], dfn[v], x, 1, n, root);
v = fa[fv], fv = top[v];
}
}
if(dep[u] >= dep[v])
update(dfn[v], dfn[u], x, 1, n, root);
else
update(dfn[u], dfn[v], x, 1, n, root);
return ;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n >> m >> r >> p;
for(int i = 1; i <= n; i++)
cin >> w[i];
for(int i = 1; i < n; i++)
{
ll u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs1(r, r);
dfs2(r, r);
for(int i = 1; i <= n; i++)
update(dfn[i], dfn[i], w[i], 1, n, root);
while(m--)
{
ll op, u, v, x;
cin >> op;
if(op == 1)
{
cin >> u >> v >> x;
upd(u, v, x);
}
if(op == 2)
{
cin >> u >> v;
cout << qry(u, v) << "\n";
}
if(op == 3)
{
cin >> u >> x;
update(dfn[u], dfn[u] + siz[u] - 1, x, 1, n, root);
}
if(op == 4)
{
cin >> u;
cout << query(dfn[u], dfn[u] + siz[u] - 1, 1, n, root) << "\n";
}
}
// cout << "\n\n----------------------\n\n";
// cout << "fa dep siz son top dfn rk\n\n";
// for(int i = 1; i <= n; i++)
// {
// cout << i << " : " << fa[i] << " " << dep[i] << " " << siz[i] << " " << son[i] << " " << top[i] << " " << dfn[i] << " " << rk[i] << "\n";
// }
return 0;
}
```