浅谈树链剖分

· · 算法·理论

处理的问题

在一棵树上,让你实现一些路径上的修改查询问题,那么很有可能树链剖分可以解决。

树链剖分的思想

因为一些问题,如果放在一条链上可以使用一些数据结构来维护(树状数组、线段树等),树链剖分就是把一颗树通过某些方法分成若干条链,来优化处理问题的时间复杂度。

树链剖分的定义

对于一个非叶节点,我们把它的子树大小最大的子节点称为重子节点,其它子节点称为轻子节点。连接重子节点的边叫重边,连接轻子节点的边叫轻边。又若干条首尾相接重边所组成的链叫重链,特别的,对于是叶子的轻子节点,其自己也组成了一条重链。根据定义可以发现任意一棵树都可被剖分成若干条重链。

树链剖分所需信息

对于树链剖分,我们需要知道任意节点的以下信息,来进行剖分:\

$top(x)$ 节点 $x$ 所在的重链的顶点\ $dep(x)$ 节点 $x$ 的深度\ $son(x)$ 节点 $x$ 的重子节点\ $siz(x)$ 节点 $x$ 的子树大小\ $dfn(x)$ 节点 $x$ 的 dfn 编号\ $rnk(x)$ dfn 编号为 x 的节点 ## 如何去求得这些信息 首先可以发现,这些信息都可以通过 dfs 来求得。可是如果每个信息都做一次 dfs 的话,太麻烦了。但是这里一些信息有依赖的关系(要先知道 $son(x)$ 才能求出 $dfn(x)$ 等)\ 我们整理一下信息可以发现,两次 dfs 即可求得所有信息。\ 第一次 dfs 求 $fa(x), dep(x), son(x), siz(x)$。\ 第二次 dfs 求 $top(x), dfn(x), rnk(x)$。 ## 树链剖分性质 树上每个节点属于且仅属于一条重链 重链上的点的 dfn 编号是连续的 对于树上任意一条路径,经过的重链条数是 $\log(n)$ 级别的 ## 模板 [模板题传送门](https://www.luogu.com.cn/problem/P3384) ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; ll n, m, r, p; ll top[100005], son[100005], fa[100005], dfn[100005], rk[100005], dep[100005], siz[100005]; vector <ll> adj[100005]; ll w[100005]; ll dt; struct segtree { ll lson, rson; ll sum, add; }tree[200005]; ll cnt, root; void change(ll &k, ll v, ll l, ll r) { if(!k) k = ++cnt; tree[k].sum += v * (r - l + 1) % p; tree[k].sum %= p; tree[k].add += v; tree[k].add %= p; return ; } void pushdown(ll k, ll l, ll r) { ll mid = l + r >> 1; change(tree[k].lson, tree[k].add, l, mid); change(tree[k].rson, tree[k].add, mid + 1, r); tree[k].add = 0; } void update(ll x, ll y, ll z, ll l, ll r, ll &k) { if(!k) k = ++cnt; if(x <= l && r <= y) { change(k, z, l, r); return ; } pushdown(k, l, r); ll mid = l + r >> 1; if(x <= mid) update(x, y, z, l, mid, tree[k].lson); if(y > mid) update(x, y, z, mid + 1, r, tree[k].rson); tree[k].sum = tree[tree[k].lson].sum + tree[tree[k].rson].sum; return ; } ll query(ll x, ll y, ll l, ll r, ll &k) { if(!k) return 0; if(x <= l && r <= y) return tree[k].sum % p; pushdown(k, l, r); ll mid = l + r >> 1, sum = 0; if(x <= mid) sum += query(x, y, l, mid, tree[k].lson), sum %= p; if(y > mid) sum += query(x, y, mid + 1, r, tree[k].rson), sum %= p; return sum % p; } void dfs1(ll u, ll f) // calc fa dep siz son { siz[u] = 1; fa[u] = f; ll mss = 0; for(ll it:adj[u]) { if(siz[it]) continue; dep[it] = dep[u] + 1; dfs1(it, u); if(siz[it] > siz[mss]) mss = it; siz[u] += siz[it]; } son[u] = mss; return ; } void dfs2(ll u, ll tp) // calc dfn rk top { dfn[u] = ++dt; rk[dfn[u]] = u; top[u] = tp; if(son[u]) dfs2(son[u], tp); for(ll it:adj[u]) { if(it == fa[u] || it == son[u]) continue; dfs2(it, it); } return ; } ll qry(ll u, ll v) { ll fu = top[u], fv = top[v], ans = 0; while(fu != fv) { if(dep[fu] >= dep[fv]) { ans += query(dfn[fu], dfn[u], 1, n, root); ans %= p; u = fa[fu], fu = top[u]; } else { ans += query(dfn[fv], dfn[v], 1, n, root); ans %= p; v = fa[fv], fv = top[v]; } } if(dep[u] >= dep[v]) ans += query(dfn[v], dfn[u], 1, n, root); else ans += query(dfn[u], dfn[v], 1, n, root); ans %= p; return ans; } void upd(ll u, ll v, ll x) { ll fu = top[u], fv = top[v], ans = 0; while(fu != fv) { if(dep[fu] >= dep[fv]) { update(dfn[fu], dfn[u], x, 1, n, root); u = fa[fu], fu = top[u]; } else { update(dfn[fv], dfn[v], x, 1, n, root); v = fa[fv], fv = top[v]; } } if(dep[u] >= dep[v]) update(dfn[v], dfn[u], x, 1, n, root); else update(dfn[u], dfn[v], x, 1, n, root); return ; } int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cin >> n >> m >> r >> p; for(int i = 1; i <= n; i++) cin >> w[i]; for(int i = 1; i < n; i++) { ll u, v; cin >> u >> v; adj[u].push_back(v); adj[v].push_back(u); } dfs1(r, r); dfs2(r, r); for(int i = 1; i <= n; i++) update(dfn[i], dfn[i], w[i], 1, n, root); while(m--) { ll op, u, v, x; cin >> op; if(op == 1) { cin >> u >> v >> x; upd(u, v, x); } if(op == 2) { cin >> u >> v; cout << qry(u, v) << "\n"; } if(op == 3) { cin >> u >> x; update(dfn[u], dfn[u] + siz[u] - 1, x, 1, n, root); } if(op == 4) { cin >> u; cout << query(dfn[u], dfn[u] + siz[u] - 1, 1, n, root) << "\n"; } } // cout << "\n\n----------------------\n\n"; // cout << "fa dep siz son top dfn rk\n\n"; // for(int i = 1; i <= n; i++) // { // cout << i << " : " << fa[i] << " " << dep[i] << " " << siz[i] << " " << son[i] << " " << top[i] << " " << dfn[i] << " " << rk[i] << "\n"; // } return 0; } ```