树链剖分 学习报告
Minakami_Yuki · · 题解
本题作为最终例题,代码在最下方给出。
众所周知,树链剖分是一个非常强的OIer(大雾
一些定义
重儿子:子树最大的儿子节点
轻儿子:其他儿子节点
重链:由重儿子组成的链
轻链:其他链
原理及实现
剖分结果
首先这是一颗树
先不谈原理,直接将剖分后的树拿出来
其中加重的边构成重链,重儿子被染成红色,轻儿子被染成蓝色。
注意11与12号节点,他们的子树大小相同,此时我们仅将节点编号小的作为重儿子。
对着上面的两幅图可以更好的理解“重”的概念。
剖分过程
我们通过两次dfs进行剖分。
规定变量
/*默认邻接表存边*/
sz[x] // 以x节点为根的子树大小
d[x] // x节点的深度
fa[x] // x节点的父节点
son[x] // x节点的重儿子(注意是重儿子)
tp[x] // x节点所在链的链顶
开始剖分
第一遍dfs处理出每个节点的深度和子树大小,从而处理出每个节点的重儿子。
void dfs1(int x) {
int maxx = -1; // 记录重儿子的子树大小以更新
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == fa[x]) continue;
d[y] = d[x] + 1;
fa[y] = x;
sz[x]++;
dfs1(y);
sz[x] += sz[y];
if(sz[y] > maxx) {
maxx = sz[y];
son[x] = y;
}
}
}
第二遍dfs处理出每个节点所在链的链顶,注意这里的链仅指重链,规定轻链上每个点的链顶都为自己。
void dfs2(int x, int top) {
tp[x] = top;
//这里要优先遍历重儿子
if(son[x]) dfs2(son[x], top);
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == son[x] || y == fa[x]) continue;
dfs2(y, y);
}
}
至此,我们已经完成了树链剖分。
处理问题
那么,树剖可以处理哪些问题呢?
在线LCA
考虑每条重链,根据定义,我们发现每个点若链顶相同,则LCA为深度较小的点。
可以结合上面的图理解一下。
有了这个性质,我们不妨尝试将每个点都转移到链顶相同的点,于是就有了下面的代码。
inline int lca(int x, int y) {
int tp1 = tp[x], tp2 = tp[y];
while(tp1 != tp2) {
if(d[tp1] > d[tp2]) {
std::swap(tp1, tp2);
std::swap(x, y);
}
y = fa[tp2];
tp2 = tp[y];
}
return d[x] < d[y] ? x : y;
}
复杂度为
例题:P3379 【模板】最近公共祖先(LCA)
板子题,直接套上面的代码。
#include <cstdio>
#include <cctype>
namespace FastIO {
inline int read() {
char ch = getchar(); int r = 0, w = 1;
while(!isdigit(ch)) {if(ch == '-') w = -1; ch = getchar();}
while(isdigit(ch)) {r = r * 10 + ch - '0', ch = getchar();}
return r * w;
}
void _write(int x) {
if(x < 0) putchar('-'), x = -x;
if(x > 9) _write(x / 10);
putchar(x % 10 + '0');
}
inline void write(int x) {
_write(x);
puts("");
}
}
using namespace FastIO;
template <typename T> inline void swap(T &x, T &y) {T tmp = x; x = y, y = tmp;}
const int N = 5e5 + 6;
const int M = N << 1;
int head[N], nxt[M], ver[M], cnt;
int tp[N], sz[N], fa[N], son[N];
int d[N];
inline void add(int x, int y) {
ver[++cnt] = y, nxt[cnt] = head[x], head[x] = cnt;
}
void dfs1(int x) {
int maxx = -1;
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == fa[x]) continue;
d[y] = d[x] + 1;
sz[x]++;
fa[y] = x;
dfs1(y);
sz[x] += sz[y];
if(sz[y] > maxx) {
maxx = sz[y];
son[x] = y;
}
}
}
void dfs2(int x, int top) {
tp[x] = top;
if(son[x]) dfs2(son[x], top);
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == son[x] || y == fa[x]) continue;
dfs2(y, y);
}
}
inline int lca(int x, int y) {
int tp1 = tp[x], tp2 = tp[y];
while(tp1 != tp2) {
if(d[tp1] > d[tp2]) {
swap(tp1, tp2);
swap(x, y);
}
y = fa[tp2];
tp2 = tp[y];
}
return d[x] < d[y] ? x : y;
}
int main() {
int n = read(), m = read(), s = read();
for(register int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y);
add(y, x);
}
dfs1(s);
dfs2(s, s);
for(register int i = 1; i <= m; i++) {
write(lca(read(), read()));
}
return 0;
}
处理子树
考虑一棵树的dfs序
dfs序
用数组dfn[x]表示节点x的dfs序,即对根节点dfs时访问到它的顺序。
这个可以在dfs2()时维护:
void dfs2(int x, int top) {
tp[x] = top;
dfn[x] = ++idx;
if(son[x]) dfs2(son[x], top);
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == son[x] || y == fa[x]) continue;
dfs2(y, y);
}
}
依然是剖分后的树,我们标记一下每个节点的dfs序:
可以发现,每个子树中的节点的dfs序是连续的。
我们用区间来表示一颗子树:
清楚多了。
可以发现每个节点区间的左端点为dfn[x],右端点为dfn[x] + sz[x] - 1
这让我们联想到线段树
于是我们可以对整棵树建立一颗线段树,这样可以区间查询获得子树信息。
例题:P3384 【模板】树链剖分
链的dfs序显然是连续的,同样放在线段树上就行了。
注意本题中子树大小将根节点计算入内,sz[x] 应初始化为1。
#include <cstdio>
#include <cctype>
#define LL long long
namespace FastIO {
inline int read() {
char ch = getchar(); int r = 0, w = 1;
while(!isdigit(ch)) {if(ch == '-') w = -1; ch = getchar();}
while(isdigit(ch)) {r = r * 10 + ch - '0', ch = getchar();}
return r * w;
}
void _write(int x) {
if(x < 0) putchar('-'), x = -x;
if(x > 9) _write(x / 10);
putchar(x % 10 + '0');
}
inline void write(int x) {
_write(x);
puts("");
}
}
using namespace FastIO;
template <typename T> inline void swap(T &x, T &y) {T tmp = x; x = y, y = tmp;}
const int N = 1e5 + 6;
const int M = N << 1;
int n, m, r, p, ptn[N], w[N];
// 线段树
#define lson (o << 1)
#define rson (o << 1) | 1
#define mid ((l + r) >> 1)
LL val[N << 2], tag[N << 2];
inline void pushdown(int o, int l, int r) {
if(tag[o]) {
tag[lson] = (tag[lson] + tag[o]) % p;
tag[rson] = (tag[rson] + tag[o]) % p;
val[lson] = (val[lson] + tag[o] * (mid - l + 1) % p) % p;
val[rson] = (val[rson] + tag[o] * (r - mid) % p) % p;
tag[o] = 0;
}
}
inline void upd(int o) {
val[o] = (val[lson] + val[rson]) % p;
}
inline void build(int o, int l, int r) {
if(l == r) {
val[o] = w[ptn[l]];
return;
}
build(lson, l, mid);
build(rson, mid + 1, r);
upd(o);
}
inline void add(int o, int l, int r, int ll, int rr, LL v) {
if(l >= ll && r <= rr) {
val[o] = (val[o] + v * (r - l + 1) % p) % p;
tag[o] = (tag[o] + v) % p;
return;
}
pushdown(o, l, r);
if(ll <= mid) add(lson, l, mid, ll, rr, v);
if(rr > mid) add(rson, mid + 1, r, ll, rr, v);
upd(o);
}
inline LL squery(int o, int l, int r, int ll, int rr) {
if(l >= ll && r <= rr) {
return val[o];
}
pushdown(o, l, r);
LL res = 0;
if(ll <= mid) res = (res + squery(lson, l, mid, ll, rr)) % p;
if(rr > mid) res = (res + squery(rson, mid + 1, r, ll, rr)) % p;
return res;
}
// 树剖
int head[N], nxt[M], ver[M], cnt;
int fa[N], sz[N], d[N], tp[N], son[N], dfn[N], idx;
inline void adde(int x, int y) {
ver[++cnt] = y, nxt[cnt] = head[x], head[x] = cnt;
}
void dfs1(int x) {
sz[x] = 1;
int maxx = -1;
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == fa[x]) continue;
d[y] = d[x] + 1;
fa[y] = x;
dfs1(y);
sz[x] += sz[y];
if(sz[y] > maxx) {
maxx = sz[y];
son[x] = y;
}
}
}
void dfs2(int x, int top) {
dfn[x] = ++idx;
ptn[dfn[x]] = x;
tp[x] = top;
if(son[x]) dfs2(son[x], top);
for(register int i = head[x]; i; i = nxt[i]) {
int y = ver[i];
if(y == son[x] || y == fa[x]) continue;
dfs2(y, y);
}
}
inline void Addline(int x, int y, LL z) {
int tx = tp[x], ty = tp[y];
while(tx != ty) {
if(d[tx] > d[ty]) {
swap(tx, ty);
swap(x, y);
}
add(1, 1, n, dfn[ty], dfn[y], z);
y = fa[ty];
ty = tp[y];
}
if(d[x] > d[y]) swap(x, y);
add(1, 1, n, dfn[x], dfn[y], z);
}
inline LL Queryline(int x, int y) {
int tx = tp[x], ty = tp[y];
LL res = 0;
while(tx != ty) {
if(d[tx] > d[ty]) {
swap(tx, ty);
swap(x, y);
}
res = (res + squery(1, 1, n, dfn[ty], dfn[y])) % p;
y = fa[ty];
ty = tp[y];
}
if(d[x] > d[y]) swap(x, y);
res = (res + squery(1, 1, n, dfn[x], dfn[y])) % p;
return res;
}
inline void Addt(int x, LL z) {
add(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, z);
}
inline LL Queryt(int x) {
return squery(1, 1, n, dfn[x], dfn[x] + sz[x] - 1);
}
int op, x, y, z;
int main() {
n = read(), m = read(), r = read(), p = read();
for(register int i = 1; i <= n; i++) w[i] = read();
for(register int i = 1; i < n; i++) {
int x = read(), y = read();
adde(x, y);
adde(y, x);
}
dfs1(r);
dfs2(r, r);
build(1, 1, n);
while(m--) {
op = read();
x = read();
if(op <= 2) y = read();
if(op == 1 || op == 3) z = read();
switch (op) {
case 1:
Addline(x, y, z);
break;
case 2:
write(Queryline(x, y));
break;
case 3:
Addt(x, z);
break;
case 4:
write(Queryt(x));
break;
}
}
return 0;
}