P9340 JOISC2023 D3T3 题解

· · 题解

传送门:P9340

题目大意:

给定一棵 n 个点的树以及一个序列 a_{1,\dots,m}q 次询问包含 a_{l,\dots,r} 中所有点的最小连通块大小。

回滚莫队一点也不优美,所以这里是 O(n \log^2 n) 解法。

step1

不妨将树的根定为 1 号结点。

不难看出题目就是求区间所有点构成的虚树大小,将在区间 [l,r] 中出现过的点染黑,其它点染白,可以将虚树大小转化成两部分:

  1. 加上所有点 u 的个数,满足 u 的子树中存在至少一个黑点
  2. 减去所有点 u 的个数,满足 u 的子树之外(包括 u全部为白点

第二部分实际上就是虚树的根的深度,这是好算的,虚树的根就是区间 [l,r] 中 dfn 最小的结点和 dfn 最大的结点的 lca。

step2

难点在于第一部分。

现在来分析右端点固定为 r 时,对于一个结点 u,左端点 l 取到何值会让 u 对第一部分有贡献。记 mx_u 表示 u 的子树中在前缀 [1,r]最后一次出现时间最大的点的最后一次出现位置,容易发现当 l \in [1,mx_u] 时,u 子树中的点 a_{mx_u} 必然为黑点,当 l \in [mx_u + 1, r] 时,u 子树中必然不存在黑点;所以 u 会且仅会对区间 [1,mx_u]+1 的贡献。

考虑扫描线:枚举 r,动态维护 mx_u 以及所有左端点的答案,初始有 mx_u = 0。观察从 r-1 枚举到 ra_r 会对 mx_u 造成什么影响:就是a_r1 的路径上的所有点的 mx_u 赋值为 i 。这很容易联想到 LCT 的 access 操作,也就是说,如果暴力对于从 a_r1 的路径上的 mx_u 相同的连续段分别修改的话,根据 access 的复杂度分析,这是均摊单次 O(\log n)(可以看作每次操作都是在 O(\log n) 条重链上做颜色段均摊)。而在整条链 mx_u 相同的情况下,我们只需要知道这条链的点数即可用 BIT 快速修改。

这个过程用 LCT 即可非常方便地模拟,总复杂度 O(n \log^2 n)

代码:

//区间虚树
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)
#define eb emplace_back
#define pb push_back
#define mp make_pair
using namespace std;

inline int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x;
}

typedef long long ll;
const int N = 1e5+5;

int n, m, q, c[N], ans[N];
int siz[N], dep[N], hs[N], f[N];
int ct[N], dfn[N], nfd[N];
int mn[N][17], mx[N][17], lg2[N];

vector<int> T[N];
vector<pair<int, int> > Q[N];

void dfs1(int x, int fa) {
    f[x] = fa;
    dep[x] = dep[fa] + 1;
    siz[x] = 1;
    for (auto son : T[x]) {
        if (son == fa) continue;
        dfs1(son, x);
        siz[x] += siz[son];
        if (siz[son] > siz[hs[x]]) hs[x] = son;
    }
}

void dfs2(int x, int ctop) {
    ct[x] = ctop;
    nfd[dfn[x] = ++dfn[0]] = x;
    if (hs[x]) dfs2(hs[x], ctop);
    for (auto son : T[x]) {
        if (son == f[x] || son == hs[x]) continue;
        dfs2(son, son);
    }
}

inline int LCA(int x, int y) {
    while (ct[x] != ct[y]) {
        if (dep[ct[x]] < dep[ct[y]]) swap(x, y);
        x = f[ct[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

int BIT[N];
inline void upd(int x, int k) { for (; x <= m; x += lowbit(x)) BIT[x] += k; }
inline void upd(int l, int r, int k) { upd(l, k), upd(r + 1, -k); }
inline int qry(int x) { int ret = 0; for (; x; x -= lowbit(x)) ret += BIT[x]; return ret; } 

namespace LCT {
    struct node {
        int son[2], fa;
        int cov, col;
    } T[N];

    #define son(x, s) T[x].son[s]
    #define f(x) T[x].fa

    inline void cov(int x, int k) { if (x) T[x].cov = T[x].col = k; }
    inline void push_down(int x) { if (int &t = T[x].cov) cov(son(x, 0), t), cov(son(x, 1), t), t = 0; }
    inline void cct(int x, int y, bool loc) { son(x, loc) = y, f(y) = x; }
    inline bool get_son(int x) { return son(f(x), 1) == x; }
    inline bool isRt(int x) { return son(f(x), 0) != x && son(f(x), 1) != x; }

    inline void rotate(int x) {
        int y = f(x);
        bool loc = get_son(x);
        f(x) = f(y);
        if (!isRt(y)) {
            son(f(y), get_son(y)) = x;
        }
        cct(y, son(x, loc ^ 1), loc);
        cct(x, y, loc ^ 1);
    }

    int stk[N], top;

    inline void splay(int x) {
        int y = x;
        stk[top = 1] = y;
        while (!isRt(y)) stk[++top] = y = f(y);
        while (top) push_down(stk[top--]);
        for (y = f(x); !isRt(x); rotate(x), y = f(x))
            if (!isRt(y))
                rotate(get_son(x) == get_son(y) ? y : x);
    }

    inline void access(int x, int t) {
        int y = 0;
        for (; x; x = f(y = x)) {
            if (y) upd(T[y].col + 1, t, dep[y] - dep[x]);
            splay(x), son(x, 1) = y;
        }
        upd(T[y].col + 1, t, dep[y]);
        cov(y, t);
    }

    #undef son
    #undef f
}

int main() {
    n = read(), m = read(), q = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        T[u].pb(v);
        T[v].pb(u);
    }

    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 2; i <= n; i++) LCT::T[i].fa = f[i];

    lg2[0] = -1;
    for (int i = 1; i <= m; i++) {
        c[i] = read();
        lg2[i] = lg2[i >> 1] + 1;
        mn[i][0] = mx[i][0] = dfn[c[i]];
        for (int j = 1; j <= lg2[i]; j++) {
            mn[i][j] = min(mn[i][j - 1], mn[i - (1 << j - 1)][j - 1]);
            mx[i][j] = max(mx[i][j - 1], mx[i - (1 << j - 1)][j - 1]);
        }
    }

    for (int i = 1; i <= q; i++) {
        int l = read(), r = read();
        Q[r].eb(l, i);
    }

    for (int i = 1; i <= m; i++) {
        LCT::access(c[i], i);
        for (auto j : Q[i]) {
            int L = j.first, R = i, len = lg2[R - L + 1];
            ans[j.second] = qry(L);
            int mnd = min(mn[R][len], mn[L + (1 << len) - 1][len]);
            int mxd = max(mx[R][len], mx[L + (1 << len) - 1][len]);
            ans[j.second] -= dep[LCA(nfd[mnd], nfd[mxd])] - 1;
        }
    }

    for (int i = 1; i <= q; i++) printf("%d\n", ans[i]);
    return 0;
}