链分治
分成三(四?)个部分。
\bf{0}. 前言
最近学链分治,感觉难度比较大,想来想去查阅了不少资料搞懂了。想写一篇笔记。
造福后人谈不上吧,就讲讲心得喵。
谨此,与诸君共勉。
\bf {1}. 重链剖分
约定
- 重儿子:儿子中子树大小最大的所在儿子。
- 轻儿子:不是重儿子的儿子。
- 重边:连向重儿子的边称为重边。
- 轻边:不是重边的边。
- 重链:全是重边祖先 - 儿子链。
贴一张图:
思想
- 引理
1 :u 子树所有结点的 dfs 序连续,具体的,为[{dfn}_u, {dfn}_u + {siz}_u - 1] 。
:::info[proof]
简单结论不证了吧。=)
感性理解。
:::
- 引理
2 :一条路径中的重链条数(也即轻边条数)是\mathcal{O}(\log n) 量级的。
:::info[proof]{open}
一条路径可以拆成
我们假设祖先链上一条
那么每经过一条轻边子树大小至少会变为
:::
先来构想一个问题吧!你需要维护一个树上数据结构,支持树链修改、查询与子树修改、查询。
看看子树修改、查询,放到 dfs 序上,这就是一个区修区查,可以用某数据结构。
再来看树链修改、查询,根据引理
dsu on tree
树上启发式合并。同样,来看一道题。
暴力做法是枚举每一个点并枚举其子树计算占据主导地位的颜色。
我们发现这里有很多部分都被重复计算了,故考虑在做的时候(也可以理解为合并的时候)保留一部分计算答案,即启发式合并。
对于计算结点
- 计算
u 的所有轻儿子的答案,并清空计算过程的变量。 - 计算
u 的重儿子的答案,并不清空过程中的变量。 - 再次遍历所有
u 的轻子树,统计答案并合并到尚未清空的重子树答案。
正确性比较好想,最后保留的恰好不重不漏的是
:::info[proof]{open} 其实可以做一个小小的转化。
观察到,每一个点只会被祖先链上的轻边的父亲方向结点计算一次,转化为祖先链轻边数量。
而根据引理
例题
:::success[P3384 【模板】重链剖分 / 树链剖分]
模板。远古构式码风,见谅。
#include <bits/stdc++.h>
#define ios \
ios::sync_with_stdio(0); \
cin.tie(0); \
cout.tie(0)
// #pragma GCC optimize(2)
#define int long long
#define pll pair<ll, ll>
#define pii pair<int, int>
#define il inline
#define p_q priority_queue
#define u_m unordered_map
#define r_g register
#define NR Nothing::read
#define NW Nothing::write
#define endl '\n'
using namespace std;
namespace Nothing {} // namespace Nothing
const int Maxn = 1000005;
int a[Maxn], tmp[Maxn];
int p;
struct SegmentTree {
#define ls (id << 1)
#define rs (id << 1 | 1)
struct Segment {
int Left;
int Right;
int valMax;
int tag;
int valSum;
} seg[Maxn << 2];
il void PushUp(int id) {
seg[id].valMax = max(seg[ls].valMax, seg[rs].valMax) % p;
seg[id].valSum = (seg[ls].valSum + seg[rs].valSum) % p;
return;
}
il void PushDown(int id) {
if (seg[id].tag) {
seg[ls].tag += seg[id].tag;
seg[ls].tag %= p;
seg[ls].valSum += seg[id].tag * (seg[ls].Right - seg[ls].Left + 1);
seg[ls].valSum %= p;
seg[rs].tag += seg[id].tag;
seg[rs].tag %= p;
seg[rs].valSum += seg[id].tag * (seg[rs].Right - seg[rs].Left + 1);
seg[rs].valSum %= p;
seg[id].tag = 0;
}
return;
}
il void Build(int id, int Left, int Right) {
seg[id] = {Left, Right, 0, 0, 0};
if (Left == Right) {
seg[id].valMax = a[Left] % p;
seg[id].valSum = a[Left] % p;
return;
}
int mid = (Left + Right) >> 1;
Build(ls, Left, mid);
Build(rs, mid + 1, Right);
PushUp(id);
return;
}
il int QuerySum(int id, int Left, int Right) {
PushDown(id);
if (seg[id].Right < Left || seg[id].Left > Right) {
return 0;
}
if (Left <= seg[id].Left && seg[id].Right <= Right) {
return seg[id].valSum % p;
}
return (QuerySum(ls, Left, Right) + QuerySum(rs, Left, Right)) % p;
}
il void Change(int id, int Left, int Right, int val) {
PushDown(id);
if (seg[id].Right < Left || seg[id].Left > Right) {
return;
}
if (seg[id].Left >= Left && Right >= seg[id].Right) {
seg[id].tag += val;
seg[id].tag %= p;
seg[id].valSum += val * (seg[id].Right - seg[id].Left + 1) % p;
seg[id].valSum %= p;
return;
}
Change(ls, Left, Right, val);
Change(rs, Left, Right, val);
PushUp(id);
return;
}
};
vector<int> G[Maxn];
int n, m, root;
struct Qtree {
struct treeNode {
int fa;
int son;
int dep;
int size;
int top;
int tid;
} tn[Maxn];
SegmentTree SEG;
int tot = 0;
void dfs1(int step, int fa) {
tn[step].fa = fa;
tn[step].dep = tn[fa].dep + 1;
tn[step].size = 1;
int Max = 0;
for (auto x : G[step]) {
if (x == fa) {
continue;
}
dfs1(x, step);
tn[step].size += tn[x].size;
if (tn[x].size > Max) {
Max = tn[x].size;
tn[step].son = x;
}
}
return;
}
void dfs2(int step, int top) {
tn[step].top = top;
tn[step].tid = ++tot;
a[tot] = tmp[step];
if (tn[step].son)
dfs2(tn[step].son, top);
for (auto x : G[step]) {
if (x == tn[step].fa || x == tn[step].son) {
continue;
}
dfs2(x, x);
}
}
void Build() {
dfs1(root, 0);
dfs2(root, root);
SEG.Build(1, 1, n);
return;
}
void Change(int u, int v, int w) {
while (tn[u].top != tn[v].top) {
if (tn[tn[u].top].dep < tn[tn[v].top].dep) { // 每次只跳链顶深度更低的,以免跳错(类似爬树法)
swap(u, v);
}
SEG.Change(1, tn[tn[u].top].tid, tn[u].tid, w % p); // 操作
u = tn[tn[u].top].fa; // 跳
}
if (tn[u].tid > tn[v].tid) {
swap(u, v);
}
SEG.Change(1, tn[u].tid, tn[v].tid, w % p); // 位于同一条重链上
return;
}
int QuerySum(int u, int v) {
int Max = 0;
while (tn[u].top != tn[v].top) {
if (tn[tn[u].top].dep < tn[tn[v].top].dep) {
swap(u, v);
}
Max += SEG.QuerySum(1, tn[tn[u].top].tid, tn[u].tid);
Max %= p;
u = tn[tn[u].top].fa;
}
if (tn[u].tid > tn[v].tid) {
swap(u, v);
}
return Max + SEG.QuerySum(1, tn[u].tid, tn[v].tid);
}
} Qt;
signed main() {
ios;
cin >> n >> m >> root >> p;
for (int i = 1; i <= n; i++) {
cin >> tmp[i];
tmp[i] %= p;
}
for (int i = 1; i < n; i++) {
int x, y;
cin >> x >> y;
G[x].push_back(y);
G[y].push_back(x);
}
Qt.Build();
while (m--) {
int Type;
cin >> Type;
if (Type == 1) {
int x, y, z;
cin >> x >> y >> z;
Qt.Change(x, y, z);
} else if (Type == 2) {
int x, y;
cin >> x >> y;
cout << Qt.QuerySum(x, y) % p << endl;
} else if (Type == 3) {
int x, z;
cin >> x >> z;
z %= p;
Qt.SEG.Change(1, Qt.tn[x].tid, Qt.tn[x].tid + Qt.tn[x].size - 1, z);
} else {
int x;
cin >> x;
cout << Qt.SEG.QuerySum(1, Qt.tn[x].tid,
Qt.tn[x].tid + Qt.tn[x].size - 1) %
p
<< endl;
}
}
return 0;
}
:::
:::success[SP6779 GSS7 - Can you answer these queries VII]
题目大意:树链覆盖,树链查最大子段和。
直接用内层线段树维护最大子段和,类似 GSS1,注意一下顺序。
#include <bits/stdc++.h>
#define int long long
#define pii pair<int, int>
#define inf 0x3f3f3f3f3f3f3f3f
#define F(x, v) for (auto x : (v))
#define ALL(x) (x).begin(), (x).end()
#define L(i, a, b) for (register int i = (a); i <= (b); i++)
#define R(i, a, b) for (register int i = (a); i >= (b); i--)
#define FRE(x) freopen(x ".in", "r", stdin), freopen(x ".out", "w", stdout)
using namespace std;
inline int cmax(int& x, int c) { return x = max(x, c); }
inline int cmin(int& x, int c) { return x = min(x, c); }
bool bgmem;
int _test_ = 1, cas;
namespace zrh {
const int N = 1e5 + 5;
int n, q, a[N], b[N], dfn[N], tot;
vector< int > g[N];
// segment tree
#define ls (u << 1)
#define rs (u << 1 | 1)
struct V {
int lsu, rsu, msu, sum;
V() { lsu = rsu = msu = sum = 0; }
V(int a, int b, int c, int d) { lsu = a, rsu = b, msu = c, sum = d; }
} tr[N << 2]; int cov[N << 2];
V operator+(V x, V y) { V ret;
ret.lsu = max(x.lsu, x.sum + y.lsu);
ret.rsu = max(y.rsu, y.sum + x.rsu);
ret.msu = max(max(x.msu, y.msu), x.rsu + y.lsu);
ret.sum = x.sum + y.sum;
return ret;
}
void up(int u) { tr[u] = tr[ls] + tr[rs]; }
void modi(int u, int L, int R, int v) {
tr[u].lsu = tr[u].rsu = tr[u].msu = max(0ll, v * (R - L + 1));
tr[u].sum = v * (R - L + 1), cov[u] = v;
}
void down(int u, int L, int R) {
if (cov[u] == inf) return;
int mid = (L + R) >> 1;
modi(ls, L, mid, cov[u]), modi(rs, mid + 1, R, cov[u]), cov[u] = inf;
}
void bui(int u, int L, int R) {
cov[u] = inf;
if (L == R) return modi(u, L, R, b[L]);
int mid = (L + R) >> 1;
bui(ls, L, mid), bui(rs, mid + 1, R), up(u);
}
void cha(int u, int L, int R, int l, int r, int c) {
if (L > r || R < l) return;
if (l <= L && R <= r) return modi(u, L, R, c);
int mid = (L + R) >> 1; down(u, L, R);
cha(ls, L, mid, l, r, c), cha(rs, mid + 1, R, l, r, c), up(u);
}
V que(int u, int L, int R, int l, int r) {
if (l <= L && R <= r) return tr[u];
int mid = (L + R) >> 1; down(u, L, R);
if (r <= mid) return que(ls, L, mid, l, r);
if (l > mid) return que(rs, mid + 1, R, l, r);
return que(ls, L, mid, l, mid) + que(rs, mid + 1, R, mid + 1, r);
}
// tree divide
int dep[N], FA[N], top[N], hea[N], siz[N];
void dfs0(int u, int fa) {
dep[u] = dep[FA[u] = fa] + 1, siz[u] = 1;
int mx = 0;
F(v, g[u]) if (v != fa) {
dfs0(v, u), siz[u] += siz[v];
if (mx < siz[v]) mx = siz[v], hea[u] = v;
}
}
void dfs1(int u, int tp) {
top[u] = tp, dfn[u] = ++tot, b[tot] = a[u];
if (hea[u]) dfs1(hea[u], tp);
F(v, g[u]) if (v != FA[u] && v != hea[u]) dfs1(v, v);
}
void tdc(int u, int v, int w) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
cha(1, 1, n, dfn[top[u]], dfn[u], w), u = FA[top[u]];
}
cha(1, 1, n, min(dfn[u], dfn[v]), max(dfn[u], dfn[v]), w);
}
int tdq(int u, int v) {
V anl, anr;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u, v);
swap(anl, anr);
}
V p = que(1, 1, n, dfn[top[u]], dfn[u]);
anl = p + anl;
u = FA[top[u]];
}
if (dep[u] > dep[v]) {
swap(u, v);
swap(anl, anr);
}
swap(anl.lsu, anl.rsu);
V p = que(1, 1, n, dfn[u], dfn[v]);
anl = anl + p;
return (anl + anr).msu;
}
void init() {}
void clear() {}
void solve() {
cin >> n; L(i, 1, n) cin >> a[i];
L(i, 1, n - 1) {
int u, v; cin >> u >> v;
g[u].push_back(v), g[v].push_back(u);
}
dfs0(1, 0), dfs1(1, 1), bui(1, 1, n);
cin >> q; L(_, 1, q) {
int op, u, v, w; cin >> op >> u >> v;
if (op == 1) cout << tdq(u, v) << "\n";
else cin >> w, tdc(u, v, w);
}
}
} // namespace zrh
bool edmem;
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
// cin >> _test_;
zrh::init();
while (++cas <= _test_) zrh::clear(), zrh::solve();
cerr << "memory: " << fabs(&edmem - &bgmem) / 1024 / 1024 << "MB\n";
cerr << "time : " << (double)clock() * CLOCKS_PER_SEC / 1000 << "ms\n";
return 0;
}
:::
:::success[一道不知道来源的题目]
题目大意:一棵有根树,以
1 为根。每个点只有一个颜色,请你求出每个点的子树的出现次数最多的颜色。有多个颜色出现次数相同时,输出最小的解。
直接上 dsu on tree,注意遍历顺序。
#include <bits/stdc++.h>
// #define fastio
#define int long long
#define bw(x) (1 << (x))
#define eb emplace_back
#define pii pair<int, int>
#define inf (0x3f3f3f3f3f3f3f3f)
#define F(x, v) for (auto x : (v))
#define ALL(x) (x).begin(), (x).end()
#define L(i, a, b) for (register int i = (a); i <= (b); ++i)
#define R(i, a, b) for (register int i = (a); i >= (b); --i)
#define FRE(i, o) freopen(i, "r", stdin), freopen(o, "w", stdout)
#define debug(a) cerr << "\033[32m[DEBUG] " << #a << " = " << (a) << " at line " << __LINE__ << "\033[0m\n"
using namespace std; bool bgmem;
#ifdef fastio
struct IO {
#define ion bw(20)
char i[ion], o[ion], *icl = i, *icr = i, *oc = o;
char gc() { return (icl == icr && (icr = (icl = i) + fread(i, 1, ion, stdin), icl == icr)) ? EOF : *icl++; }
void pc(char c) { if (oc - o == ion) fwrite(o, 1, ion, stdout), oc = o; *oc++ = c; }
void rd(auto &x) { char c; int f = 1; x = 0; while (c = gc(), c < '0' || c > '9') if (c == '-') f = -1; while ('0' <= c && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = gc(); x *= f; }
void pr(auto x) { int a[64], p = 0; if (x < 0) x = -x, pc('-'); do a[p++] = x % 10; while (x /= 10); while (p--) pc(a[p] + '0'); }
IO& operator>>(char &c) { return c = gc(), *this; }
IO& operator<<(char c) { return pc(c), *this; }
IO& operator<<(const char *c) { while (*c) pc(*c++); return *this; }
IO& operator>>(auto &x) { return rd(x), *this; }
IO& operator<<(auto x) { return pr(x), *this; }
} io;
#define cin io
#define cout io
#endif
template < const int P > class modint { public:
int x;
modint(int _ = 0) { x = (_ % P + P) % P; }
explicit operator int() { return x; }
friend modint operator+(modint x, modint y) { return x.x + y.x < P ? x.x + y.x : x.x + y.x - P; }
friend modint operator-(modint x, modint y) { return x.x - y.x < 0 ? x.x - y.x + P : x.x - y.x; }
friend modint operator*(modint x, modint y) { return modint((__int128)(x.x * y.x)); }
modint& operator+=(modint x) { return (*this) = (*this) + x; }
modint& operator-=(modint x) { return (*this) = (*this) - x; }
modint& operator*=(modint x) { return (*this) = (*this) * x; }
};
inline int cmax(int& x, int c) { return x = max(x, c); }
inline int cmin(int& x, int c) { return x = min(x, c); }
int tes = 1, cas;
namespace qwq {
const int N = 3e5 + 5;
int n, a[N], tmp[N], dep[N], fa[N], hea[N], siz[N], cnt[N], ans, ret[N]; vector< int > g[N];
void dfs(int u, int f) {
fa[u] = f, dep[u] = dep[f] + 1, siz[u] = 1;
F(v, g[u]) if (v != f) {
dfs(v, u);
siz[u] += siz[v];
if (siz[hea[u]] < siz[v]) hea[u] = v;
}
} void sol1(int u, int w, int p = 0) {
cnt[a[u]] += w;
if (cnt[a[u]] > cnt[ans]) ans = a[u];
if (cnt[a[u]] == cnt[ans]) cmin(ans, a[u]);
F(v, g[u]) if (v != fa[u] && v != p) sol1(v, w, p);
} void sol2(int u, bool fl) {
F(v, g[u]) if (v != fa[u] && v != hea[u]) sol2(v, 0);
if (hea[u]) sol2(hea[u], 1), sol1(u, 1, hea[u]);
else sol1(u, 1);
ret[u] = tmp[ans];
if (!fl) sol1(u, -1), ans = 0;
}
void init() {}
void clear() {}
void solve() {
cin >> n; L(i, 1, n) cin >> a[i], tmp[i] = a[i];
sort(tmp + 1, tmp + n + 1);
int cnt = unique(tmp + 1, tmp + n + 1) - tmp - 1;
L(i, 1, n) a[i] = lower_bound(tmp + 1, tmp + cnt + 1, a[i]) - tmp;
L(i, 1, n - 1) {
int u, v; cin >> u >> v;
g[u].eb(v), g[v].eb(u);
}
dfs(1, 0), sol2(1, 1);
L(i, 1, n) cout << ret[i] << "\n";
}
} bool edmem; signed main() {
// FRE("", "");
#ifndef fastio
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
#endif
// cin >> tes;
qwq::init(); while (++cas <= tes) qwq::clear(), qwq::solve();
#ifdef fastio
fwrite(io.o, 1, io.oc - io.o, stdout);
#endif
cerr << "time : " << (double)clock() / CLOCKS_PER_SEC * 1000 << "ms\n";
cerr << "memory: " << fabs(&edmem - &bgmem) / 1024 / 1024 << "mb\n";
return 0;
}
// 其实是不能说来源。qwq
:::
优缺点
优
- 祖先链重链条数优秀,
\mathcal{O}(\log n) 的量级。对应的可以快速维护树上数据结构题。 - 带一个
\frac{1}{4} 倍常数,极小。不知道咋出来的。^1
缺
- 你说得对但是加上线段树就是
\mathcal{O}(n \log ^ 2 n) 了,还没 LCT 优秀(不过实际跑起来快多了就是了)。
\bf{2.} 长链剖分
约定
类似重链剖分,定义长儿子为结点儿子中子树深度最大的那一个。
思想
- 引理
3 :一个结点祖先中长链条数(也即短边数量)是\mathcal{O}(\sqrt{n}) 量级的。
:::info[proof]{open}
考虑构造一种最坏情况,也即每个短儿子的长链长度都比父亲恰好大
显然此时祖先链长链条数最多的是图中右下角的点,设其为
满足
- 引理
4 :u 的k 级祖先所在长链的长度不小于k 。
:::info[proof]{}
设