题解 [模拟赛 2023.02.21] 攀比之心

· · 个人记录

首先有博弈结论:

证明:

现在我们只需要用所有情况的方案数,减去「从 1 出发的最长链只有一条」的情况的方案数即可。

考虑树形 dp。设 $dp_{u,i}$ 表示在 $u$ 的子树内选出一个包含 $u$ 的连通块,且连通块中的点到 $u$ 的最大距离恰好为 $i$ 的方案数。

暴力转移可以每次加入一个新儿子,长链剖分优化即可。注意需要打乘积的 lazy 标记。

时间复杂度为 $O(n \log \mathrm{mod})$,其中 $\log \mathrm{mod}$ 来自用快速幂求逆元。

代码:

#include <stdio.h>

/* 64-bit integer used for all modular arithmetic in this file. */
typedef long long ll;

/*
 * One directed edge of the adjacency list (forward-star layout).
 * Edges leaving the same vertex are linked through `nxt`.
 */
typedef struct {
    int nxt; /* index in edge[] of the next edge from the same vertex; 0 ends the list */
    int end; /* vertex this edge points to */
} Edge;

/* Modulus for every count computed below. */
const int mod = 1e9 + 7;
/* Number of directed edges stored in edge[] so far (edges are 1-based). */
int cnt;
/*
 * Per-vertex data:
 *   head[v]      - index of the first edge leaving v (0 = none)
 *   fa[v]        - parent of v, filled by dfs1
 *   depth[v]     - depth of v, filled by dfs1 (the root gets depth[0] + 1)
 *   max_depth[v] - number of vertices on the longest downward path starting at v
 *   hs[v]        - child of v with the largest max_depth (0 for a leaf)
 *   top[v]       - first vertex of the heavy chain containing v, filled by dfs2
 */
int head[200007], fa[200007], depth[200007], max_depth[200007], hs[200007], top[200007];
/*
 * dp1 / mul1 are backing pools; p / q are the next free positions in them
 * (reset by init, advanced whenever a new chain is allocated).
 * dp2[v] / mul2[v] point into the pools and are indexed from 1:
 * dp2[v][i] is a DP value, mul2[v][i] a pending (lazy) multiplier.
 * pre[] is scratch space used by main when combining the root's children.
 */
ll dp1[400007], *p, mul1[400007], *q, pre[200007], *dp2[200007], *mul2[200007];
/* Edge storage; each undirected input edge occupies two entries. */
Edge edge[400007];

/*
 * Resets per-test-case state: clears the edge counter, rewinds the two
 * pool cursors to the start of dp1 / mul1, and empties the adjacency list
 * heads of vertices 1..n.
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so a call that the compiler chooses not to
 * inline (e.g. at -O0) would fail to link.
 */
static inline void init(int n){
    cnt = 0;
    p = dp1;
    q = mul1;
    for (int i = 1; i <= n; i++){
        head[i] = 0;
    }
}

/*
 * Reads the next non-negative decimal integer from stdin.
 *
 * Skips every character that is not a digit, then accumulates consecutive
 * digits. Returns 0 if end of input is reached before any digit is seen.
 *
 * `ch` is an int so that EOF is distinguishable from every character value;
 * with a `char` the skip loop could never terminate at end of input.
 * `static inline` keeps the definition local to this file: a plain `inline`
 * definition provides no external symbol in C99/C11, and an un-inlined call
 * to a function named `read` could then resolve to the C library's read().
 */
static inline int read(void){
    int ans = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')){
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9'){
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return ans;
}

/*
 * Appends the directed edge start -> end to the adjacency list of `start`.
 * The new edge is placed at the front of the list, so edges are later
 * visited in reverse insertion order.
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so an un-inlined call would fail to link.
 */
static inline void add_edge(int start, int end){
    cnt++;
    edge[cnt].nxt = head[start];
    head[start] = cnt;
    edge[cnt].end = end;
}

/*
 * First traversal: records the parent and depth of u, then, after visiting
 * every child, stores in max_depth[u] the number of vertices on the longest
 * downward path from u and in hs[u] the child that realises it (0 if u has
 * no children). Ties keep the child encountered first.
 */
void dfs1(int u, int father){
    int e;

    fa[u] = father;
    depth[u] = depth[father] + 1;
    max_depth[u] = hs[u] = 0;

    e = head[u];
    while (e != 0){
        int child = edge[e].end;
        e = edge[e].nxt;
        if (child == father){
            continue;
        }
        dfs1(child, u);
        if (max_depth[child] > max_depth[u]){
            max_depth[u] = max_depth[child];
            hs[u] = child;
        }
    }

    /* Count u itself on the path. */
    max_depth[u] += 1;
}

/*
 * Second traversal: labels every vertex with the first vertex of its chain.
 * The preferred child hs[u] continues the current chain; every other child
 * starts a new chain headed by itself.
 */
void dfs2(int u, int cur_top){
    int e;

    top[u] = cur_top;
    if (hs[u] == 0){
        return; /* leaf: nothing below to label */
    }

    dfs2(hs[u], cur_top);
    for (e = head[u]; e != 0; e = edge[e].nxt){
        int child = edge[e].end;
        if (child == fa[u] || child == hs[u]){
            continue;
        }
        dfs2(child, child);
    }
}

/*
 * Moves the pending multiplier stored at slot y of vertex x one slot forward.
 *
 * If mul2[x][y] is not 1, slot y + 1 has both its value (dp2) and its own
 * pending multiplier (mul2) multiplied by it, and the tag at y is reset to 1.
 * Calling this for increasing y therefore carries a tag along the array
 * one position at a time.
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so an un-inlined call would fail to link.
 */
static inline void pushdown(int x, int y){
    if (mul2[x][y] != 1){
        int yi = y + 1;
        mul2[x][yi] = mul2[x][yi] * mul2[x][y] % mod;
        dp2[x][yi] = dp2[x][yi] * mul2[x][y] % mod;
        mul2[x][y] = 1;
    }
}

/*
 * Computes the DP array of u by merging its children, reusing storage along
 * the preferred-child chain.
 *
 * Precondition: dp2[u] and mul2[u] already point at storage for u (set by
 * the caller, or by the parent for a preferred child).
 * Arrays are indexed from 1. Per the accompanying write-up, dp values count
 * connected blocks containing u by their maximum distance from u; slot 1
 * presumably corresponds to distance 0 - TODO confirm.
 */
void dfs3(int u){
    /* Slot 1 starts at value 1 with a neutral pending multiplier. */
    dp2[u][1] = mul2[u][1] = 1;
    if (hs[u] != 0){
        /* The preferred child shares u's storage, shifted by one slot:
           its slot i is u's slot i + 1. */
        dp2[hs[u]] = dp2[u] + 1;
        mul2[hs[u]] = mul2[u] + 1;
        dfs3(hs[u]);
        for (register int i = head[u]; i != 0; i = edge[i].nxt){
            int x = edge[i].end;
            if (x != fa[u] && x != hs[u]){
                /* Every other child gets a fresh block of max_depth[x] + 1
                   entries carved from the pools. */
                int t = max_depth[x] + 1;
                /* sum1: 1 plus the old values of u's slots 2..j seen so far.
                   sum2: 1 plus x's slots 1..j-2 at the point it is used. */
                ll sum1 = 1, sum2 = 1;
                dp2[x] = p;
                p += t;
                mul2[x] = q;
                q += t;
                dfs3(x);
                /* Merge x into u: x's slot j - 1 lines up with u's slot j. */
                for (register int j = 2; j <= t; j++){
                    /* Carry pending multipliers forward from both slots
                       before reading or overwriting them. */
                    pushdown(u, j);
                    pushdown(x, j - 1);
                    sum1 = (sum1 + dp2[u][j]) % mod;
                    /* sum1 already includes the old dp2[u][j]; sum2 does not
                       yet include dp2[x][j - 1]. */
                    dp2[u][j] = (sum1 * dp2[x][j - 1] % mod + sum2 * dp2[u][j] % mod) % mod;
                    sum2 = (sum2 + dp2[x][j - 1]) % mod;
                }
                /* u's slots beyond x's range are all scaled by the final
                   sum2. Apply it to slot t + 1 directly and leave it as a
                   lazy tag there; pushdown carries it further on demand. */
                if (t < max_depth[u]){
                    t++;
                    mul2[u][t] = mul2[u][t] * sum2 % mod;
                    dp2[u][t] = dp2[u][t] * sum2 % mod;
                }
            }
        }
    }
}

/*
 * Computes x raised to the power p, modulo `mod`, by binary exponentiation.
 *
 * x   - base; reduced into [0, mod) first so that the products below stay
 *       within 64 bits for any modulus that fits in 32 bits
 * p   - exponent; values <= 0 yield 1 (a negative value would otherwise
 *       never reach zero under an arithmetic right shift)
 * mod - modulus, must be positive (shadows the file-level constant)
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so an un-inlined call would fail to link.
 */
static inline ll quick_pow(ll x, ll p, ll mod){
    ll ans = 1;
    x %= mod;
    if (x < 0) x += mod;
    while (p > 0){
        if (p & 1) ans = ans * x % mod;
        x = x * x % mod;
        p >>= 1;
    }
    return ans;
}

/*
 * Entry point. Reads the number of test cases; for each one reads n and
 * n - 1 undirected edges, roots the tree at vertex 1, runs the DP on every
 * neighbour of the root and prints "Case #i: <answer>" modulo `mod`.
 *
 * Per the accompanying write-up the answer is "all configurations minus
 * those where the longest chain from vertex 1 is unique"; the two phases
 * below presumably correspond to those two terms - TODO confirm.
 */
int main(){
    int t = read();
    for (register int i = 1; i <= t; i++){
        int n = read();
        ll ans = 1;
        init(n);
        for (register int j = 1; j < n; j++){
            int u = read(), v = read();
            /* Store each undirected edge in both directions. */
            add_edge(u, v);
            add_edge(v, u);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        for (register int j = 0; j <= max_depth[1]; j++){
            pre[j] = 1;
        }
        /* Phase 1: run the DP for each neighbour x of the root.
           With sum = 1 + dp2[x][1] + ... + dp2[x][k-1], pre[k] collects the
           product over all x of (sum + dp2[x][k]) / sum, and ans collects
           the product over all x of the final sum. */
        for (register int j = head[1]; j != 0; j = edge[j].nxt){
            int x = edge[j].end;
            ll sum = 1;
            dp2[x] = p;
            p += max_depth[x] + 1;
            mul2[x] = q;
            q += max_depth[x] + 1;
            dfs3(x);
            for (register int k = 1; k <= max_depth[x]; k++){
                /* Flush the pending multiplier so dp2[x][k + 1] is final
                   by the time it is read. */
                pushdown(x, k);
                pre[k] = pre[k] * (sum + dp2[x][k]) % mod;
                /* For k == 1 sum is exactly 1, so the division is skipped.
                   NOTE(review): the inverse is taken as sum^(mod-2); if sum
                   is ever 0 modulo `mod` this yields 0 rather than an
                   inverse - confirm that cannot happen. */
                if (k > 1) pre[k] = pre[k] * quick_pow(sum, mod - 2, mod) % mod;
                sum = (sum + dp2[x][k]) % mod;
            }
            ans = ans * sum % mod;
        }
        /* Turn pre[] into running products: pre[j] = pre[0] * ... * pre[j]. */
        for (register int j = 1; j <= max_depth[1]; j++){
            pre[j] = pre[j] * pre[j - 1] % mod;
        }
        /* Phase 2: for every neighbour x and slot k, subtract
           pre[k - 1] * dp2[x][k] / (1 + dp2[x][1] + ... + dp2[x][k-1]),
           keeping the result in [0, mod). */
        for (register int j = head[1]; j != 0; j = edge[j].nxt){
            int x = edge[j].end;
            ll sum = 1;
            for (register int k = 1; k <= max_depth[x]; k++){
                ans = ((ans - pre[k - 1] * dp2[x][k] % mod * quick_pow(sum, mod - 2, mod) % mod) % mod + mod) % mod;
                sum = (sum + dp2[x][k]) % mod;
            }
        }
        printf("Case #%d: %lld\n", i, ans);
    }
    return 0;
}