树分治

· · 算法·理论

Empty

本文同步发表在 博客园

静态树分治

这个算法可以处理关于树上所有链信息的问题,由于枚举所有链的复杂度绝对是 O(n^2)。所以需要更优秀的算法,可以通过合并两条链的方式枚举所有链。

树分治是一个暴力数据结构。

它可以在 O(n \log n) 的时间里遍历所有的链。 下面的算法可以处理树形态不变的情况下的问题。

边分治

找一条边,把树分成两半。在两边子树找两点,则有一条过这条边的路径。

再对两个子树分治。

可是菊花图会被卡成 n^ 2

使用 “左二子,右兄弟” 的方法可以优化。

可是也有局限性(我并不是太了解边分治)。

点分治。

相比于边分治,点分治更通用一些,后面的点分树也是由此转化得来。

我们每次找到一个点 u,处理经过 u 的点的答案,处理子树之间的路径。

删去 u

再对 u 的子树分治。

这样可以 “遍历” 所有路径。

每次分治的中心选取 重心 是最优的(每次子树大小除以 2)。

于是时间复杂度是 O(n \log n) 的。

所以我们可以每次遍历所有分治中心的子树,暴力计算答案

一般有两种方式计算答案。

  1. 容斥。用所有点两两的答案减去子树内的答案。

  2. 数据结构维护。每次用遍历当前子树,用数据结构查询之前子树的答案,在把当前字树加入数据结构。

代码实例:

int sze[400010], dp[400010], vis[400010]; // vis:是否当做过分治中心
int tot, rt;
void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}

void solve(int u) {
    vis[u] = 1;
    ans += getans(u); // 以容斥计算答案为例
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v); // 
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

好了,至此我们已经学会了点分治,下面要开始实战了。

先来一道最基础的题:

例1.1 Tree

给定大小为 n带权 树。

求长度小于等于 k 的路径条数。

**Sol:** 设当前的分治中心为 $u$。 答案要加上:$\sum_{subtree(x) \neq subtree(y)} [dis(x) + dis(y) \le k]

考虑把和式 x, y 的限制去掉。

它等于所有的 x, y,减去 x, y 在同一子树的答案。

\sum_{x, y}[dis(x) + dis(y) \le k] - \sum_{subtree(x) = subtree(y)} [dis(x) + dis(y) \le k]

可以用容斥算。

对于 \sum_{x, y}[dis(x) + dis(y) \le k] 这样的式子。可以用双指针维护。

看代码吧:

void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
void getdis(int u, int fa) {
    rev[++tot] = d[u];
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (!vis[v] && v != fa) {
            d[v] = d[u] + w;
            getdis(v, u);
        }
    }
}// 算距离
int getans(int u, int w) {
    tot = 0, d[u] = w;
    getdis(u, 0);
    sort(rev + 1, rev + tot + 1);
    int l = 1, r = tot, tmp = 0;
    while (l <= r) {
        if (rev[l] + rev[r] <= k)
            tmp += r - l, ++l;
        else
            r--;
    }
    return tmp;
}// 得到 rev 数组里两两的答案
void solve(int u) {
    vis[u] = 1;
    ans += getans(u, 0);// 加上所有的点的答案。
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v, w); // 减去子树内的答案,得到两两子树的答案
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

练1.1.1 聪聪可可

P2634 [国家集训队] 聪聪可可

题意:给你一颗 带权 树,求有多少条路径满足长度是 3 的倍数。

sol:这算是刚才那道题改了一下(更简单了)。

还是利用容斥。

一样的,答案为:\sum_{x, y}[3 | dis(x) + dis(y) ] - \sum_{subtree(x) = subtree(y)} [3 | dis(x) + dis(y)]

容易计算。

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 20010;
#define int long long
vector<pair<int, int> > edge[N];
int n;
int sze[N], dp[N];
int vis[N];
int d[N];
int tot, rt, sum;
int rev[N];
int ans;
void getrt(int u, int fa) {
    sze[u] = 1, dp[u] = 0;
    for (auto y : edge[u]) {
        int v = y.first;
        if (vis[v] || v == fa)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
void getdis(int u, int fa) {
    rev[++tot] = d[u];
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (!vis[v] && v != fa) {
            d[v] = d[u] + w;
            getdis(v, u);
        }
    }
}
int getans(int u, int w) {
    tot = 0, d[u] = w;
    getdis(u, 0);
    int s[3] = { 0, 0, 0 };
    for (int i = 1; i <= tot; i++) {
        s[rev[i] % 3]++;
    }

    return s[0] * s[0] + s[1] * s[2] * 2;
}
void solve(int u) {
    vis[u] = 1;
    ans += getans(u, 0);
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        ans -= getans(v, w);
        sum = sze[v];
        dp[0] = n, rt = 0;
        getrt(v, u);
        solve(rt);
    }
}

int gcd(int a, int b) { return (b == 0) ? a : gcd(b, a % b); }

signed main() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        edge[u].push_back({ v, w });
        edge[v].push_back({ u, w });
    }
    dp[0] = sum = n;

    getrt(1, 0);
    solve(rt);
    cout << ans / gcd(ans, n * n) << "/" << n * n / gcd(ans, n * n);
    return 0;
}

练1.1.2 重建计划

(P4292 [WC2010] 重建计划)[https://www.luogu.com.cn/problem/P4292]

题意:

给你一棵带权树:求一条边数在 [L, R] 路径,使权值和除以边数最大,

设边权为 w_1, w_2, \cdots, w_k

答案为:\frac{w_1 + w_2 + \cdots w_k}{k}

二分答案 x,要使得:

\frac{w_1 + w_2 + \cdots w_k}{k} \ge x

{w_1 + w_2 + \cdots w_k} \ge kx

{(w_1 - x) + (w_2 - x) + (w_3 - x) \cdots (w_k - x)} \ge 0

每次二分先把边权减去 x

我们要求最长的一条路径的长度 \ge 0

还是点分治。每次用线段树维护边数在一段区间的答案,可是这样做的时间复杂度为 O(n \log^3 n)

注意到每次查询的区间为 [L - dep, R - dep]

按照子树内 dep 从大到小排序的话,可以用单调队列维护,查询完答案后与原数组的值取 \max

这样做还会有问题,单调队列的大小由最深的子树决定。最坏复杂度为 n^2 \log n

所以我们得按照子树深度从小到大的顺序枚举。

时间复杂度 O(n \log^2 n)

代码: 这里放的是同学的代码,因为我是用 n \log^3 n 的算法卡过去的。 见谅。

#include <bits/stdc++.h>
using namespace std;

typedef int ll;
typedef double ld;

const ll Pig = 2e5 + 10;

ll n, L, R, dis[Pig], cnt[Pig], ans, sze[Pig], p[Pig], cur, ln[Pig], len;
ld d[Pig], buc[Pig], val;
vector<pair<ll, ll> > g[Pig];
vector<ll> pnt;
bitset<Pig> vis;

void dfs_ln(ll i, ll f, ll t) {
    dis[i] = dis[f] + 1;
    ln[t] = max(ln[t], dis[i]);

    for (auto j : g[i]) {
        if (j.first == f or vis[j.first])
            continue;

        dfs_ln(j.first, i, t);
    }
}

void dfs1(ll i, ll f) {
    sze[i] = 1;
    p[i] = 0;

    for (auto j : g[i]) {
        if (j.first == f or vis[j.first])
            continue;

        ll v = j.first, w = j.second;
        d[v] = d[i] + w - val;
        dis[v] = dis[i] + 1;
        dfs1(v, i);
        sze[i] += sze[v];
        p[i] = max(p[i], sze[v]);
    }
}

void dfs2(ll i) {
    if (ans)
        return;

    vector<ll> v, curr;
    vector<pair<ll, ll> > gg;
    buc[0] = 0;
    dis[i] = 1;
    vis[i] = 1;
    len = 0;

    for (auto j : g[i]) {
        if (!vis[j.first]) {
            ln[j.first] = 0;
            dfs_ln(j.first, i, j.first);
            gg.emplace_back(j);
        }
    }

    sort(gg.begin(), gg.end(), [&](pair<ll, ll> a, pair<ll, ll> b) { return ln[a.first] < ln[b.first]; });

    for (auto j : gg) {
        cur = j.first;
        d[cur] = j.second - val;
        dis[j.first] = 1;
        dfs1(cur, i);
        queue<ll> q;
        deque<ll> c;
        vector<ll> point;
        q.emplace(cur);
        ll r = -1;

        while (!q.empty()) {
            ll pt = q.front();
            p[pt] = max(p[pt], sze[j.first] - sze[pt]);

            if (p[pt] < p[cur])
                cur = pt;

            q.pop();
            point.emplace_back(pt);
            for (auto k : g[pt])
                if (dis[k.first] > dis[pt])
                    q.emplace(k.first);
        }

        reverse(point.begin(), point.end());

        for (ll k : point) {
            while (r < R - dis[k] and r < len) {
                r++;
                while (!c.empty() and buc[c.back()] < buc[r]) c.pop_back();
                c.emplace_back(r);
            }
            while (!c.empty() and c.front() + dis[k] < L) c.pop_front();
            if (!c.empty() and buc[c.front()] + d[k] >= 0)
                ans = 1;
        }

        for (ll k : point) buc[dis[k]] = max(buc[dis[k]], d[k]), curr.emplace_back(k), len = max(len, dis[k]);

        v.emplace_back(cur);

        if (ans) {
            vis[i] = 0;
            buc[0] = buc[Pig - 1];

            for (ll k : curr) buc[dis[k]] = buc[Pig - 1];

            return;
        }
    }

    buc[0] = buc[Pig - 1];

    for (ll k : curr) buc[dis[k]] = buc[Pig - 1];

    if (ans) {
        vis[i] = 0;
        return;
    }

    for (ll j : v) dfs2(j);

    vis[i] = 0;
}

ll read() {
    char c = getchar();
    ll res = 0;
    bool flg = 1;

    while (!isdigit(c)) {
        if (c == '-')
            flg = 0;
        c = getchar();
    }

    while (isdigit(c)) res = (res << 3) + (res << 1) + (c ^ '0'), c = getchar();

    if (!flg)
        res = -res;
    return res;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.setf(ios::fixed);
    cout.precision(3);
    memset(buc, -0x3f, sizeof(buc));
    n = read();
    L = read();
    R = read();
    bool flg = 1;

    for (ll i = 1, u, v, w; i < n; i++) {
        u = read();
        v = read();
        w = read();
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }

    ld l = 0, r = 1e6;

    while (fabs(l - r) > 4.5e-4) {
        val = (l + r) / 2;
        ans = 0;
        dfs2(1);

        if (ans)
            l = val;
        else
            r = val;
    }

    cout << (l + r) / 2;
    return 0;
}

练1.1.3 Yin and Yang G

P3085 [USACO13OPEN] Yin and Yang G

题意:给你一棵边权为 1 或者 -1 的树。求有多少条路径 u \to vu \ne v),满足路径上存在一点 pp\neq np \neq m)使得 \mathrm{dis(u, p)} = \mathrm{dis(p, m)} = 0

sol: 等价于求 \mathrm{dis(u, v)} = \mathrm{dis(u, p) = 0}

设分治中心为 r,子树内的点到它的距离记为 d_i

假设在它的子树中有两点 u, v,使得 d_u + d_v = 0,那是不是在 u 或者 v 的祖先中存在一个点 p 使得 dis_u = dis_p 或者 dis_v = dis_p 才行。即选的点 pdis 要和 u 或者 v 的值相等。

我们按照一个点 u 的祖先有没有点和 d_u 相等把点分成两类。

没有的点只能与有的点配成一对,而有的点可以和两种点配对。

这个用数组统计。需要注意的是要仔细考虑 ur 配对的情况。

代码: 为了避免负数,我将下标平移了一段区间,也可以使用 map 解决。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e5 + 10;
int n;
vector<pair<int, int> > edge[N];
int dp[N], sze[N], vis[N], cnt[N], res[N], cnt1[N];
int rt, sum, ans;

void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], sum - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}

int mn, mx;
bool flag = 0;
void getans(int u, int f, int dis) {
    ans += (res[dis + 100000] > 0) * cnt[-dis + 100000] + cnt1[-dis + 100000];
    if (dis == 0 && res[dis + 100000] > 1)
        ans++;
    res[dis + 100000]++;//统计当前点到 $r$ 的 $d$
    mn = min(mn, dis);
    mx = max(mx, dis);
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getans(v, u, dis + w);
    }
    res[dis + 100000]--;
}

void getdis(int u, int f, int dis) {
    cnt[dis + 100000] += (res[dis + 100000] == 0);
    cnt1[dis + 100000] += (res[dis + 100000] > 0);
    res[dis + 100000]++;//统计当前点到 $r$ 的 $d$

    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v] || v == f)
            continue;
        getdis(v, u, dis + w);
    }
    res[dis + 100000]--;
}

void solve(int u) {
    vis[u] = 1;
    res[100000] = 1;
    mn = 1e5, mx = -1e5;
    flag = 0;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;
        getans(v, u, w);
        getdis(v, u, w);
    }
    for (int i = mn; i <= mx; i++) res[i + 100000] = cnt[i + 100000] = cnt1[i + 100000] = 0;
    for (auto y : edge[u]) {
        int v = y.first, w = y.second;
        if (vis[v])
            continue;

        sum = sze[v];
        rt = 0, dp[rt] = n + 1;
        getrt(v, u);
        solve(rt);
    }
}

signed main() {
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        if (w == 0)
            w = -1;
        edge[u].push_back({ v, w });
        edge[v].push_back({ u, w });
    }
    sum = n + 1;
    rt = 0, dp[rt] = n + 1;
    getrt(1, 0);
    solve(rt);
    cout << ans;
    return 0;
}

广义点分治

传统的点分治帮助我们快速统计树上所有路径。

可以发现分治的层数 不超过 \log n 层。

于是一类问题出现了,问一棵树上的最优点 x

怎么刻画最优是题目规定的,如带权重心(P3345 幻想乡战略游戏)。

我们先假设 f(x) 为这个点的答案。

如果这样的点只有一个,且满足单调性:若这个点的答案不优,则它的子树的答案一定更不优。

所以答案一定在满足 f(x) < f(fa_x) 的子树里。

我们可以从 fa_x 跳到 x 子树的重心。

则最优点一定在点分树的子子树里。

例1.2 快递员

题意: (不想改题面啦qwq)

Showson 的城市里面有 n 家快递站,被 n - 1 条带权无向边相连。

Showson 需要送 m 个快递,第 i 个货物需要从 u 送到 v。由于 Showson 不能带着货物走太长的路,所以对于一次送货,他需要先从集散中心到 u,再从 u 回到集散中心,再从集散中心到 v,最后从 v 返回集散中心。换句话说,如果设集散中心开在 c 号点,那么他的路径是 c \rightarrow u \rightarrow c \rightarrow v \rightarrow c

现在 Showson 希望确定一个点作为集散中心的开设位置,使得他送货所需的最长距离最小。显然,这个最长距离是个偶数,你只需要输出最长距离除以 2 的结果即可。

sol

点分治。

设现在的快递中心为 r

先暴力求出快递中心到所有点对的距离和。

找到所有使答案最大的点对。

r 在它们的路径上,则答案无法再小。

若两组点对所在的子树不同,答案也无法再小。

假如可以再小。

就往那个子树分治。

#include<bits/stdc++.h>
using namespace std;
//#define int long long

const int N = 1e5 + 10;
int n, m; 
vector<pair<int, int> > edge[N];

int vis[N], dp[N], sze[N], tot, rt;

struct node{
    int x, y;
}a[N];

void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(vis[v] || v == f) continue;
        getrt(v, u);
        dp[u] = max(dp[u], sze[v]);
        sze[u] += sze[v];
    }
    dp[u] = max(dp[u], tot - sze[u]);
    if(dp[u] < dp[rt]) rt = u;
}
int d[N], sub[N], ans = 1e9;
void getdis(int u, int f, int s) {
    sub[u] = s;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(v == f) continue;
        d[v] = d[u] + w;
        getdis(v, u, s);
    }
}

void print() {
    cout << ans;
    exit(0);
}

int solve(int u) {
    if(vis[u]) print();
    vis[u] = 1;
    d[u] = 0;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        d[v] = w;
        getdis(v, u, v);
    }
    int mx = 0;
    vector<int> v;
    for(int i = 1; i <= m; i++) {
        int x = a[i].x, y = a[i].y;
        if(d[x] + d[y] > mx) v.clear(), v.push_back(i), mx = d[x] + d[y];
        else if(d[x] + d[y] == mx) v.push_back(i); 
    }
    if(mx < ans) ans = mx;
    int lst = 0;
    for(auto i : v) {
        int x = a[i].x, y = a[i].y;
        if(sub[x] != sub[y]) print(); 
        if(lst == 0) lst = sub[x];
        if(lst != sub[x]) print();
    }
    rt = 0, dp[0] = n + 1, tot = sze[lst];
    getrt(lst, u);
    solve(rt);
}

signed main() {
    cin >> n >> m;
    for(int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        edge[u].push_back({v, w});
        edge[v].push_back({u, w});
    }
    for(int i = 1; i <= m; i++) {
        cin >> a[i].x >> a[i].y;
    }
    rt = 0, dp[0] = n + 1, tot = n;
    getrt(1, 0);
    solve(rt);
    print();
    return 0;
}

练2.2.1 幻想乡战略游戏

题意

有一棵带权树。

找到一个点 u,使得 \sum_{v = 1}^{n} dis(u, v) \times a_v 最小。

Sol

这道题要用点分树。

请学习完点分树以后再来查看。

做法差不多。

S(x) 表示 x 点的子树里的点权和。

设当前最优 x 的答案为 resx 为根。

考略 $\Delta res_{x \to y} = res + (S(x) - 2S(y)) \times w(x, y)$。 $\Delta res_{x \to y} < 0 \Leftrightarrow S(x) < 2\times S(y)$。 这样的 $y$ 至多只有一个。 证明: 假设存在两个 $y1$, $y2$,使得 $S(x) < 2S(y1), S(x) < S(y2)$。 则 $2S(x) < 2(S(y1) + S(y2))$,所以 $S(y1) + S(y2) > S(x)$,不成立。 于是我们可以每次枚举点 $x$ 的所有儿子,找到最优的那个,对他进行点分治即可。 计算 $\sum_{v = 1}^{n} dis(u, v) \times a_v$ 直接点分树就行了。 ```cpp #include <bits/stdc++.h> using namespace std; #define int long long const int N = 1e5 + 10; struct node { int v, w, rt; }; int n; vector<node> edge[N]; int dep[N], top[N], Dis[N], sze[N], son[N], Fa[N]; void dfs1(int u, int f) { dep[u] = dep[f] + 1; Fa[u] = f; sze[u] = 1; for (auto y : edge[u]) { int v = y.v, w = y.w; if (v == f) continue; Dis[v] = Dis[u] + w; dfs1(v, u); sze[u] += sze[v]; if (sze[v] > sze[son[u]]) son[u] = v; } } void dfs2(int u, int tp) { top[u] = tp; if (son[u]) dfs2(son[u], tp); for (auto y : edge[u]) { int v = y.v; if (v != Fa[u] && v != son[u]) dfs2(v, v); } } int lca(int u, int v) { while (top[u] != top[v]) { if (dep[top[u]] < dep[top[v]]) swap(u, v); u = Fa[top[u]]; } if (dep[u] < dep[v]) return u; return v; } int dis(int u, int v) { return Dis[u] + Dis[v] - 2 * Dis[lca(u, v)]; } int Mx[N], vis[N], rt, cnt; void getrt(int u, int f) { sze[u] = 1, Mx[u] = 0; for (auto y : edge[u]) { int v = y.v, w = y.w; if (v == f || vis[v]) continue; getrt(v, u); sze[u] += sze[v], Mx[u] = max(Mx[u], sze[v]); } Mx[u] = max(Mx[u], cnt - sze[u]); if (Mx[u] < Mx[rt]) rt = u; } int fa[N]; // 点分树父亲 void init(int u) { vis[u] = 1; for (auto &y : edge[u]) { int v = y.v, w = y.w; if (vis[v]) continue; rt = 0, Mx[0] = n + 1, cnt = sze[v]; getrt(v, u); y.rt = rt; fa[rt] = u; init(rt); } } int f1[N], f2[N], sum[N]; void modify(int x, int val) { for (int u = x; u; u = fa[u]) { sum[u] += val; } for (int u = x; fa[u]; u = fa[u]) { int D = dis(fa[u], x); f1[fa[u]] += D * val; f2[u] += D * val; } } int query(int x) { int res = f1[x]; for (int u = x; fa[u]; u = fa[u]) { int D = dis(x, fa[u]); res += (f1[fa[u]] - f2[u]); res += (sum[fa[u]] - sum[u]) * D; } return res; } int solve(int u) { int res = query(u); for (auto y : edge[u]) { int v = y.v; if (query(v) < res) { return solve(y.rt); } } return res; } signed main() { cin.tie(0), cout.tie(0); ios::sync_with_stdio(0); int T; cin >> n >> T; for (int i = 1; i < n; i++) { int u, v, w; cin >> u >> v >> w; edge[u].push_back({ v, w, 0 }); edge[v].push_back({ u, w, 0 }); } dfs1(1, 0); dfs2(1, 0); rt = 0, Mx[0] = n + 1, cnt = n; getrt(1, 0); int Rt = rt; init(rt); while (T--) { int x, v; cin >> x >> v; modify(x, v); cout << solve(Rt) << "\n"; } return 0; } ``` ## 动态树分治 动态树分治,又称点分树。 我会用尽量精简生动的语言将他描述出来。 这个数据结构真的挺难的,不过真的挺有用的。 每次点分治,我们会把这一层的重心与上一层的重心连边。 得到一颗树,称为点分树。 这样原本的父子关系就被完全打乱了。 那这个对我们解决问题有什么帮助呢? **有些问题我们并不关心树的形态,比如并查集,联通块问题**。我们要求出两点间的路径,也不一定要求出 $LCA$。我们可以找一个分割点 $p$,把路径分为 $u \to p$ 和 $p \to v$。 点分树就是对原树做了这样的映射。 点分树有如下性质: 1. 它的 **深度** 为 $\log n$,与点分治的层数一样。我们可以枚举点分树上的的所有父亲。甚至可以开一个 `vector` 存下每个点子树内的点。 2. 对两点 $(u, v)$,它们在点分树上的 $lca$.一定在它们的路径上, 也就是说, $dis(u, v) = dis(u, lca) + dis(lca, v)$。注意 $dis$ 是在原树上的距离。 **计算贡献**: 以下用 $fa_x$ 表示 $x$ 在点分树上的父亲节点,$subtree(x)$ 表示 $x$ 在点分树上的子树节点集合,$A(x)$ 表示 $x$ 的所有祖先节点集合,$dis(x, y)$ 表示两点在 **原树上的距离**。 枚举所有祖先节点当做中转点。 设 $ans(i, j)$ 表示距离 $i$ 小于等于 $j$ 的点的点权和。 设以 $a$ 为中转点,由于 $a$ 在点分树的子树里的点已经被统计过了,那么要统计的是除去 $a$ 在 $x$ 这侧的子树的所有点到 $x$ 的距离小于等于 $j$ 的答案。 设 $f1(i, j)$ 表示在 $i$ 点分树的子树里的点到 $j$ 的距离小于等于 $j$ 的点权和。 $$ f1(i, j) = \sum_{x \in subtree(i) \land dis(x, i) \le j} a_x $$ 为了除去某个点的子树。 设 $f2(i, j)$ 表示在 $i$ 点分树的子树里的点到 $fa_i$ 的距离小于等于 $j$ 的点权和。 $$ f2(i, j) = \sum_{x \in subtree(i) \land dis(x, fa_i) \le j} a_x $$ 于是 $ans$ 可以计算。 $$ ans(i, j) = f1(i, j) + \sum_{x \in A(i) \land fa(x) \land dis(i, x) \le j} f1(fa_x, j - dis(i, fa_x)) - f1(x, j - dis(i, fa_x)) $$ 我们看到例题。 #### 例2.1 震波 在一片土地上有 $n$ 个城市,通过 $n-1$ 条无向边互相连接,形成一棵树的结构,相邻两个城市的距离为 $1$,其中第 $i$ 个城市的价值为 $value_i$。 不幸的是,这片土地常常发生地震,并且随着时代的发展,城市的价值也往往会发生变动。 接下来你需要在线处理 $m$ 次操作: `0 x k` 表示发生了一次地震,震中城市为 $x$,影响范围为 $k$,所有与 $x$ 距离不超过 $k$ 的城市都将受到影响,该次地震造成的经济损失为所有受影响城市的价值和。 `1 x y` 表示第 $x$ 个城市的价值变成了 $y$ 。 为了体现程序的在线性,操作中的 $x$、$y$、$k$ 都需要异或你程序上一次的输出来解密,如果之前没有输出,则默认上一次的输出为 $0$ 。 思路: 只需要处理修改操作。 直接暴力在点分树上跳父亲。 看代码吧,注意树状数组细节: ```cpp #include<bits/stdc++.h> using namespace std; //#define int long long const int N = 2e5 + 10; int n, m; vector<int > edge[N]; struct BIT{ int sze; vector<int> c; void resize(int x) { sze = x - 1; c.resize(x); } void add(int x, int y) { x++; for(;x <= sze; x += x & -x) c[x] += y; } int query(int x) { x++; int res = 0; x = min(x, sze); for(;x ; x -= x & -x) res += c[x]; return res; } }w0[N], w1[N]; int fa[N][19], dep[N], lg2[N]; void dfs0(int u, int f) { fa[u][0] = f, dep[u] = dep[f] + 1; for(auto v : edge[u]) if(v != f) dfs0(v, u); } int Lca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); while(dep[x] > dep[y]) x = fa[x][lg2[dep[x] - dep[y]]]; if(x == y) return x; for(int i = 18; i >= 0; i--) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i]; return fa[x][0]; } int getdis(int x, int y) { return dep[x] + dep[y] - 2 * dep[Lca(x, y)]; } //上面部分为预处理。 int vis[N], tot, rt, dp[N], sze[N]; void getrt(int u, int f) { dp[u] = 0, sze[u] = 1; for(auto v : edge[u]) { if(v == f || vis[v]) continue; getrt(v, u); sze[u] += sze[v]; dp[u] = max(sze[v], dp[u]); } dp[u] = max(dp[u], tot - sze[u]); if(dp[u] < dp[rt]) rt = u; } int dfa[N], dsze[N]; void init(int u, int f) { dfa[u] = f;// 点分树上的父亲 vis[u] = 1; w1[u].resize(tot + 2); w0[u].resize(tot + 2);// 注意空间 for(auto v : edge[u]) { if(vis[v]) continue; rt = 0, dp[0] = n + 1, tot = sze[v]; getrt(v, u); init(rt, u); } } int a[N]; void modify(int u, int w) { for(int i = u; i; i = dfa[i]) w0[i].add(getdis(u, i), w); for(int i = u; dfa[i]; i = dfa[i]) w1[i].add(getdis(u, dfa[i]), w); } int query(int u, int k) { int res = w0[u].query(k); for(int i = u; dfa[i]; i = dfa[i]) { int dis = getdis(u, dfa[i]); if(k >= dis) res += w0[dfa[i]].query(k - dis) - w1[i].query(k - dis); } return res; } signed main() { cin.tie(0), cout.tie(0); ios::sync_with_stdio(0); cin >> n >> m; for(int i = 1; i <= n; i++) cin >> a[i]; for(int i = 1; i < n; i++) { int u, v; cin >> u >> v; edge[u].push_back(v); edge[v].push_back(u); } for(int i = 1; i <= n; i++) lg2[i] = log2(i); dfs0(1, 0); for(int i = 1; i <= 18; i++) for(int j = 1; j <= n; j++) fa[j][i] = fa[fa[j][i - 1]][i - 1]; rt = 0, dp[0] = n + 1, tot = n; getrt(1, 0); init(rt, 0); for(int i = 1; i <= n; i++) modify(i, a[i]); int ans = 0; while(m--) { int op, x, y; cin >> op >> x >> y; x ^= ans, y ^= ans; if(op == 0) { ans = query(x, y); cout << ans << "\n"; } else { modify(x, y - a[x]), a[x] = y; } } return 0; } ``` #### 练 2.1.1 [P3241 [HNOI2015] 开店](https://www.luogu.com.cn/problem/P3241) [P3241 [HNOI2015] 开店](https://www.luogu.com.cn/problem/P3241) 设 $$ f1(i, j) = \sum_{x \in subtree(i)\land w_x \le j} dis(i, x) \\ f2(i, j) = \sum_{x \in subtree(i)\land w_x \le j} dis(fa_i, x) \\ g1(i, j) = \sum_{x \in subtree(i)\land w_x \le j} 1 \\ g2(i, j) = \sum_{x \in subtree(i)\land w_x \le j} 1 $$ 注意这里是计算距离和,之前只加上了 $x$ 到 $fa_i$ 的距离,所以在查询 $p$ 点时,漏掉了 $fa_i$ 到 $p$ 这一段。 $ans(i, j) = f1(i, j) + \sum_{x \in A(i)} \{f1(fa_i, j) - f2(i, j) + (g1(fa_i, j) - g2(i, j)) \times dis(i, x)\}

可以点分树时对每个点开 vector 记录子树内 w 的值查询时二分即可。

代码: 二分:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 2e5 + 10;
int n, M;

vector<pair<int, int> > edge[N];
int a[N];
struct SGT{
    vector<pair<int, int> > v;
    void change(int x, int y) {
        v.push_back({x, y});
    }

    void init() {
        sort(v.begin(), v.end());
        int sze = v.size();
        for(int i = 1; i < sze; i++) {
            v[i].second += v[i - 1].second;
        }
    }

    int query(int x) {
        auto p = upper_bound(v.begin(), v.end(), make_pair(x, (int)1e14));
        if(p == v.begin()) return 0;
        p--;
        return (*p).second;
    }

    int query(int l, int r) {
        return query(r) - query(l - 1);
    }
}w1[N], w2[N], w3[N];

int Fa[N][20], dep[N], dis[N], lg2[N];

void dfs0(int u, int f) {
    Fa[u][0] = f, dep[u] = dep[f] + 1;
    for(int i = 1; i <= 19; i++) Fa[u][i] = Fa[Fa[u][i - 1]][i - 1];
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(v != f) dis[v] = dis[u] + w, dfs0(v, u);
    }   
}

int Lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    while(dep[x] > dep[y]) x = Fa[x][lg2[dep[x] - dep[y]]];
    if(x == y) return x;
    for(int i = 19; i >= 0; i--) if(Fa[x][i] != Fa[y][i]) x = Fa[x][i], y = Fa[y][i];
    return Fa[x][0];
}
int getdis(int x, int y) {
    return dis[x] + dis[y] - 2 * dis[Lca(x, y)];
}

int dp[N], sze[N], vis[N], s[N], cnt, rt;

void getrt(int u, int f) {
    sze[u] = 1, dp[u] = 0;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(v == f || vis[v]) continue;
        getrt(v, u);
        sze[u] += sze[v];
        dp[u] = max(dp[u], sze[v]);
    }
    dp[u] = max(dp[u], cnt - sze[u]);
    if(dp[u] < dp[rt]) rt = u;
}

int fa[N];

void init(int u) {
    vis[u] = 1;
    s[u] = cnt;
    for(auto y : edge[u]) {
        int v = y.first, w = y.second;
        if(vis[v]) continue;
        rt = 0, dp[0] = n + 1, cnt = sze[v];
        getrt(v, u);
        fa[rt] = u;
        init(rt);
    }
}

void modify(int u) {
    for(int i = u; i ; i = fa[i]) w1[i].change(a[u], getdis(u, i));
    for(int i = u; fa[i]; i = fa[i]) w2[i].change(a[u], getdis(u, fa[i]));  
    for(int i = u; i; i = fa[i]) w3[i].change(a[u], 1);
}

int query(int u, int l, int r) {
    int res = w1[u].query(l, r);
    for(int i = u; fa[i] ; i = fa[i]) {
        res = res + w1[fa[i]].query(l, r) - w2[i].query(l, r) + getdis(u, fa[i]) * (w3[fa[i]].query(l, r) - w3[i].query(l, r)); 
    }
    return res;
}

signed main() {
    cin.tie(0), cout.tie(0);
    ios::sync_with_stdio(0);
    int T;
    cin >> n >> T >> M;
    for(int i = 1; i <= n; i++) {
        cin >> a[i]; a[i]++;
    }
    for(int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        edge[u].push_back({v, w});
        edge[v].push_back({u, w});
    }
    dfs0(1, 0);
    rt = 0, dp[0] = n + 1, cnt = n;
    getrt(1, 0);
    init(rt);
    for(int i = 1; i <= n; i++) lg2[i] = log2(i);
    for(int i = 1; i <= n; i++) modify(i);
    for(int i = 1; i <= n; i++) w1[i].init(), w2[i].init(), w3[i].init();
    int ans = 0;
    while(T--) {
        int u, a, b;
        cin >> u >> a >> b;
        int L = min((a + ans) % M, (b + ans) % M) + 1, R = max((a + ans) % M, (b + ans) % M) + 1;
        cout << (ans = query(u, L, R)) << "\n";
    }
    return 0;
} 

动态开店线段树版,只有结构体有变化

struct SGT {
    vector<pair<int, int> > v;
    void change(int x, int y) { v.push_back({ x, y }); }

    void init() {
        sort(v.begin(), v.end());
        int sze = v.size();
        for (int i = 1; i < sze; i++) {
            v[i].second += v[i - 1].second;
        }
    }

    int query(int x) {
        auto p = upper_bound(v.begin(), v.end(), make_pair(x, (int)1e14));
        if (p == v.begin())
            return 0;
        p--;
        return (*p).second;
    }

    int query(int l, int r) { return query(r) - query(l - 1); }
} w1[N], w2[N], w3[N];

练 2.1.2 P5311 [Ynoi2011] 成都七中

P5311 [Ynoi2011] 成都七中

题意:

给你一棵 n 个节点的树,每个节点有一种颜色,有 m 次查询操作。

查询操作给定参数 l\ r\ x,需输出:

将树中编号在 [l,r] 内的所有节点保留,x 所在连通块中颜色种类数。

每次查询操作独立。

Sol

点分树有性质:树上任意一个联通块,存在一个在点分树上深度最小的点,并且整个联通块都在这个点的子树当中。

证明:

问题变成:在一棵树中,从根节点出发,只经过 [l,r] 范围内的点,可以到达的颜色数。

使用反证法。

假设点分树上最浅的节点叫做 p,联通块中有一个点 q 不在 p 的子树内。由于点分树上 p 的子树也是原树中的一个联通块,并且所有 p 子树之外的点到达 p 都必须经过一个比p更浅的节点,所以q到p的路径上最浅的点一定比 p 浅,而同时这个点一定也在联通块内,这违反了“ p 是最浅的节点”。

所以我们可以把每一个询问 l,r,x,归在和 x 在同一个联通块中的点分树上最浅的节点。然后就只需要对点分树上每个点遍历一遍子树,分开来处理。

注意要判断这个点是否能走到点分树上的祖先节点。

我们记录一下每一个节点到(点分树上的)每个祖先的路径上的编号最大的点和编号最小的点,分别设成 LR

我们发现,只有对于 x 节点拥有的一个询问 (l,r) ,只有 L\ge lR \le r的节点才能对答案有贡献。

将询问离线一波,第一维排序第二维树状数组维护即可解决。

#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int n;
vector<int> edge[N];
int c[N];
void add(int x, int y) {
    for (; x <= n; x += x & -x) c[x] += y;
}

int query(int x) {
    if (!x)
        return 0;
    int res = 0;
    for (; x; x -= x & -x) res += c[x];
    return res;
}

void del(int x) {
    for (; x <= n; x += x & -x) c[x] = 0;
}

struct node {
    int l, r, op;
};

int a[N];
int ans[N];
vector<node> q[N];

int vis[N], dp[N], sze[N], rt, cnt;
void getrt(int u, int f) {
    dp[u] = 0, sze[u] = 1;
    for (auto v : edge[u]) {
        if (v == f || vis[v])
            continue;
        getrt(v, u);
        dp[u] = max(dp[u], sze[v]);
        sze[u] += sze[v];
    }
    dp[u] = max(dp[u], cnt - sze[u]);
    if (dp[u] < dp[rt])
        rt = u;
}
int mn[N], mx[N], col[N];
vector<node> t;
void getres(int u, int f) {
    mn[u] = min(mn[f], u), mx[u] = max(mx[f], u);
    t.push_back({ mn[u], mx[u], -a[u] });
    for (auto x : q[u]) {
        if ((!ans[x.op]) && x.l <= mn[u] && mx[u] <= x.r)
            t.push_back({ x.l, x.r, x.op });
    }
    for (auto v : edge[u]) {
        if (vis[v] || v == f)
            continue;
        getres(v, u);
    }
}

bool cmp(node x, node y) {
    if (x.l != y.l)
        return x.l > y.l;
    return x.op < y.op;
}

void solve(int u) {
    vis[u] = 1;
    t.clear();
    mx[0] = -1e9, mn[0] = 1e9;
    getres(u, 0);
    sort(t.begin(), t.end(), cmp);
    for (auto x : t) {
        if (x.op < 0) {
            x.op *= -1;
            if (!col[x.op])
                add(x.r, 1), col[x.op] = x.r;
            else if (x.r < col[x.op]) {
                add(col[x.op], -1);
                add(x.r, 1);
                col[x.op] = x.r;
            }
        } else {
            ans[x.op] = query(x.r) - query(x.l - 1);
        }
    }
    for (auto x : t) {
        if (x.op < 0)
            del(x.r), col[-x.op] = 0;
    }

    for (auto v : edge[u]) {
        if (vis[v])
            continue;
        rt = 0, dp[0] = n + 1, cnt = sze[v];
        getrt(v, u);
        solve(rt);
    }
}

int main() {
    int T;
    cin >> n >> T;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }
    for (int i = 1; i <= T; i++) {
        int l, r, x;
        cin >> l >> r >> x;
        q[x].push_back({ l, r, i });
    }
    rt = 0, dp[0] = n + 1, cnt = n;
    getrt(1, 0);
    solve(rt);
    for (int i = 1; i <= T; i++) cout << ans[i] << "\n";
    return 0;