P3899 解题报告

· · 题解

题意:

有一棵树,求 ab 均为 c 的祖先,且 ab 的距离不超过 k 的方案数

思路:

ab 在树上的关系进行分类讨论:

  1. ba 的祖先,则 ca 的子树中的一个节点。根据乘法原理,方案数为 \min(k, dep_a-1) * (siz_a-1)。其中,\min(k, dep_a-1) 表示可取的 b 的方案数,siz_a-1 表示 a 的子树的节点个数(不包括 a 本身),即可取的 c 的方案数。

  2. ab 的祖先,则 cb 的子树中的一个节点。此时,方案数为所有深度在 [dep_a + 1, dep_a + k] 之间的 b 的贡献,即所有 siz_b-1 的和。

这两种情况之和即为答案。

法一:可持久化线段树

  1. 记录每个节点的 \operatorname{dfs} 序。

  2. 对于每个 \operatorname{dfs} 序,建立一棵线段树,叶子节点下标即为深度,区间方案数的前缀和记为 \operatorname{cnt}

  3. 查询 [dep_a + 1, dep_a + k] 的区间和,即为查询版本 dfn_a 和版本 dfn_a + siz_a - 1 的区间 \operatorname{cnt} 之差。

#include <bits/stdc++.h>
#define il inline
namespace Fast_IO {
    template <typename T> il void read(T &x) {
        x = 0; int f = 0; char ch = getchar();
        while (!isdigit(ch))f |= (ch == '-'), ch = getchar();
        while (isdigit(ch))x = x * 10 + (ch - 48), ch = getchar();
        x = f ? -x : x;
    }
    template <typename T, typename ...Args>
    il void read(T &x, Args& ...args) {read(x), read(args...);}
    template <typename T> il void write(T x, char c = '\n') {
        if (x) {
            if (x < 0)x = -x, putchar('-');
            char a[30]; short l;
            for (l = 0; x; x /= 10)a[l ++] = x % 10 ^ 48;
            for (l --; l >= 0; l --)putchar(a[l]);
        } else putchar('0');
        putchar(c);
    }
    template <typename T, typename ...Args>
    il void write(T x, Args ...args) {write(x), write(args...);}
} using namespace Fast_IO;
using namespace std;

#define int long long
const int Maxn = 300005;
struct Segment_Tree {int l, r, cnt;} tree[Maxn << 5];
struct Edge {int to, nxt;} edge[Maxn << 1];
int n, Q, dep[Maxn], maxdep, dfn[Maxn], tim, siz[Maxn];
int head[Maxn], tot, rt[Maxn], cnt;
void update(int &u, int pre, int l, int r, int x, int k) { //点修 
    u = ++ cnt;
    tree[u] = tree[pre];
    tree[u].cnt += k;
    if (l == r) return ;
    int mid = (l + r) >> 1;
    if (x <= mid)update(tree[u].l, tree[pre].l, l, mid, x, k);
    else update(tree[u].r, tree[pre].r, mid + 1, r, x, k);
}
int query(int u, int v, int l, int r, int x, int y) { //区查 
    if (y < l || x > r) return 0;
    if (x <= l && r <= y) return tree[v].cnt - tree[u].cnt;
    int mid = (l + r) >> 1;
    return query(tree[u].l, tree[v].l, l, mid, x, y) + query(tree[u].r, tree[v].r, mid + 1, r, x, y);
}
il void add(int u, int v) {
    edge[++ tot].to = v;
    edge[tot].nxt = head[u];
    head[u] = tot;
}
void dfs1(int u, int f) { //求树的 dfs 序、大小、深度、最大深度 
    dep[u] = dep[f] + 1;
    dfn[u] = ++ tim;
    siz[u] = 1;
    for (int i = head[u]; i; i = edge[i].nxt) {
        int v = edge[i].to;
        if (v == f) continue;
        dfs1(v, u);
        maxdep = max(maxdep, dep[v]);
        siz[u] += siz[v];
    }
}
void dfs2(int u, int f) { //建树,siz[u]-1 为节点 u 的贡献值 
    update(rt[dfn[u]], rt[dfn[u] - 1], 1, maxdep, dep[u], siz[u] - 1);
    for (int i = head[u]; i; i = edge[i].nxt)
        if (edge[i].to != f)
            dfs2(edge[i].to, u);
}
signed main() {
    read(n, Q);
    for (int i = 1, u, v; i < n; i ++)
        read(u, v), add(u, v), add(v, u);
    dfs1(1, 0), dfs2(1, 0);
    while (Q --) {
        int a, k; read(a, k);
        int ans1 = min(k, dep[a] - 1) * (siz[a] - 1); //情况1 
        int ans2 = query(rt[dfn[a]], rt[dfn[a] + siz[a] - 1], 1, maxdep, dep[a] + 1, dep[a] + k); //情况2 
        write(ans1 + ans2);
    }
    return 0;
}

法二:线段树合并

  1. 对每个节点 a,以 dep_a 为下标,以 siz_b - 1 为权值,维护一棵权值线段树。

#include <bits/stdc++.h>
#define il inline
namespace Fast_IO {
    template <typename T> il void read(T &x) {
        x = 0; int f = 0; char ch = getchar();
        while (!isdigit(ch))f |= (ch == '-'), ch = getchar();
        while (isdigit(ch))x = x * 10 + (ch - 48), ch = getchar();
        x = f ? -x : x;
    }
    template <typename T, typename ...Args>
    il void read(T &x, Args& ...args) {read(x), read(args...);}
    template <typename T> il void write(T x, char c = '\n') {
        if (x) {
            if (x < 0)x = -x, putchar('-');
            char a[30]; short l;
            for (l = 0; x; x /= 10)a[l ++] = x % 10 ^ 48;
            for (l --; l >= 0; l --)putchar(a[l]);
        } else putchar('0');
        putchar(c);
    }
    template <typename T, typename ...Args>
    il void write(T x, Args ...args) {write(x), write(args...);}
} using namespace Fast_IO;
using namespace std;

#define int long long
const int Maxn = 300005;
struct Segment_Tree {int l, r, sum;} tree[Maxn << 6];
struct Edge {int to, nxt;} edge[Maxn << 1];
int rt[Maxn], cnt, dep[Maxn], siz[Maxn], n, Q, head[Maxn], tot;
il void add(int u, int v) {
    edge[++ tot].to = v;
    edge[tot].nxt = head[u];
    head[u] = tot;
}
void update(int &u, int l, int r, int x, int k) {
    if (!u)u = ++ cnt;
    if (l == r) {
        tree[u].sum += k;
        return ;
    }
    int mid = (l + r) >> 1;
    if (x <= mid)update(tree[u].l, l, mid, x, k);
    else update(tree[u].r, mid + 1, r, x, k);
    tree[u].sum = tree[tree[u].l].sum + tree[tree[u].r].sum;
}
int query(int u, int l, int r, int x, int y) {
    if (y < l || x > r) return 0;
    if (x <= l && r <= y) return tree[u].sum;
    int mid = (l + r) >> 1;
    return query(tree[u].l, l, mid, x, y) + query(tree[u].r, mid + 1, r, x, y);
}
int merge(int u, int v) {
    if (!u || !v) return u + v;
    int now = ++ cnt;
    tree[now].sum = tree[u].sum + tree[v].sum;
    tree[now].l = merge(tree[u].l, tree[v].l);
    tree[now].r = merge(tree[u].r, tree[v].r);
    return now;
}
void dfs(int u, int f) {
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    for (int i = head[u]; i; i = edge[i].nxt) {
        int v = edge[i].to;
        if (v == f) continue;
        dfs(v, u);
        siz[u] += siz[v];
        rt[u] = merge(rt[u], rt[v]);
    }
    update(rt[u], 1, n, dep[u], siz[u] - 1);
}
signed main() {
    read(n, Q);
    for (int i = 1, u, v; i < n; i ++)
        read(u, v), add(u, v), add(v, u);
    dfs(1, 0);
    while (Q --) {
        int a, k; read(a, k);
        int ans1 = min(k, dep[a] - 1) * (siz[a] - 1);
        int ans2 = query(rt[a], 1, n, dep[a] + 1, min(dep[a] + k, n));
        write(ans1 + ans2);
    }
    return 0;
}