【题解】牛客网CSP-S集训4 路径计数机

· · 个人记录

计数万古如长夜啊,大样例老是过不去啊

什么?数漏了,再开个数组把漏掉的加上就可以了

什么?又数重复了,再开个数组把重复的容斥掉即可

什么?转移速度太慢,再开一个数组记录前缀和就优化了

什么?函数名还有哪些啊,都用完了

恶心的树形\text{DP}, 我使用了11\text{DP}数组终于\text{AC}

\text{f[x][i], g[x][i], F[x][i], G[x][i]} \text{h[x][i], Sum[x][i], S[x], R[x], d[x][i],sum\_p[x], sum\_q[x]}

解释下含义

f[x][i]:从x出发向下走长度为i的链的个数

F[x][i]:从x出发长度为i的链的个数

g[x][i]:从x出发向上走长度为i的链的个数

G[x][i]:删除以x为根的子树,长度为i的链的个数\times 2(咳咳,这个和g[x][i]没有任何关系)

h[x][i]:从x的子树外,到x的子树内,长度为i的链的个数

d[x][i]:在x的子树内,经过x的长度为i的链的个数\times 2

Sum[x][i]:在x的子树内,F[][i]数组的和

S[i]:整棵树,F[][i]的数组的和

sum_p[x]:在x的子树内,d[][p]的和\times \dfrac{1}{2}

sum_q[x]:在x的子树内,d[][q]的和\times \dfrac{1}{2}

R[x]:在x的子树内,对于所有不相交的长度为p和长度为q的链,设这两条链的端点分别为a, bc, d,满足LCA(a, b, c, d) \notin \{a, b, c, d\},的对数

\color{Red}{\text{约定x表示任意树上节点,v是x的孩子}}
f[x][0] = d[x][0] = 1 f[x][i] = \sum f[v][i - 1] F[v][i] = f[v][i] + F[x][i - 1] - f[v][i - 2] [i \ge 2]

注意上面那个[i \ge 2]是艾弗森括号

g[v][i] = F[x][i - 1] - f[v][i - 2] d[x][i] = f[v][j - 1] \times f[x][i - j], j \in [1, i - 1], d[x][i] += f[x][i], d[x][i] *= 2

注意,上柿子的前半部分是在每次f数组更新前更新,后半部分是遍历子节点结束后进行

Sum[x][i] = \sum Sum[v][i] + F[x][i] S[i] = \sum F[x][i] h[x][i] = \sum g[x][j] \times f[x][i - j], j \in [1, i] G[x][i] = S[i] - Sum[x][i] - h[x][i] sum\_p[x] = \sum sum\_p[v] + d[x][p] \times \dfrac{1}{2} sum\_q[x] = \sum sum\_q[v] + d[x][q] \times \dfrac{1}{2} R[x] = \sum R[v] + sum\_p[x] \times sum\_q[v] + sum\_q[x] \times sum\_p[v]

注意,柿子后面的都是在每次sum\_p[x]sum\_q[x]更新前更新

最后激动人心的Ans = \sum d[x][p] \times G[x][q] + \sum d[x][q] \times G[x][p] - R[1] \times 4

完结撒花ヾ(✿゚▽゚)ノ

//ID: LRL52  Date: 2019.11.5
//计数万古如长夜啊
//恶心的树形DP, 我使用了11个DP数组终于AC了 
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define ee(i, a) for(int i = head[a]; i; i = e[i].nxt)
#include<bits/stdc++.h>
using namespace std;
const int N = 3055, M = 2055; char ss[1 << 21], *A = ss, *B = ss, cc;
inline char gc(){return A == B && (B = (A = ss) + fread(ss, 1, 1 << 21, stdin), A == B) ? EOF : *A++;}
template<class T>inline void rd(T &x){
    int f = 1; x = 0, cc = gc(); while(cc < '0' || cc > '9'){if(cc == '-') f = -1; cc = gc();}
    while(cc >= '0' && cc <= '9'){x = x * 10 + (cc ^ 48); cc = gc();} x *= f;
}
#define int long long
int n, ct, P, Q, K;
int head[N], f[N][N], d[N][N], Sum[N][N], S[N], G[N][N], F[N][N], h[N][N], g[N][N];
int R[N], sum_p[N], sum_q[N];
struct edge{
    int v, nxt;
}e[N << 1];

inline void adde(int from, int to){
    e[++ct] = (edge){to, head[from]};
    head[from] = ct;
}

void dfs(int x, int far){
    f[x][0] = d[x][0] = 1;
    ee(I, x){
        int v = e[I].v;
        if(v == far) continue;
        dfs(v, x);
        for(int i = 1; i <= K; ++i){
            for(int j = 1; j < i; ++j){
                d[x][i] += f[v][j - 1] * f[x][i - j];
            }
        }
        for(int i = 1; i <= K; ++i)
            f[x][i] += f[v][i - 1];
    }
    for(int i = 1; i <= K; ++i){
        d[x][i] += f[x][i];
        d[x][i] *= 2;
    }
}

void dfs2(int x, int far){
    ee(I, x){
        int v = e[I].v;
        if(v == far) continue;
        F[v][0] = 1;
        for(int i = 1; i <= K; ++i){
            F[v][i] = f[v][i] + F[x][i - 1];
            g[v][i] = F[x][i - 1];
            if(i >= 2){
                F[v][i] -= f[v][i - 2];
                g[v][i] -= f[v][i - 2];
            }
        }
        dfs2(v, x);
    }
}

void dfs3(int x, int far){
    for(int i = 0; i <= K; ++i) Sum[x][i] = F[x][i];
    for(int i = 1; i <= K; ++i){
        for(int j = 1; j <= i; ++j){
            h[x][i] += g[x][j] * f[x][i - j];
        }
    }
    ee(I, x){
        int v = e[I].v;
        if(v == far) continue;
        dfs3(v, x);
        R[x] += R[v];
        R[x] += sum_p[x] * sum_q[v] + sum_q[x] * sum_p[v];
        sum_p[x] += sum_p[v];
        sum_q[x] += sum_q[v];
        for(int i = 0; i <= K; ++i)
            Sum[x][i] += Sum[v][i];
    }
    sum_p[x] += d[x][P] / 2;
    sum_q[x] += d[x][Q] / 2;
    for(int i = 0; i <= K; ++i)
        G[x][i] = S[i] - Sum[x][i] - h[x][i];
}

#undef int
int main(){
#ifdef LRL52
    freopen("b.in", "r", stdin);
#endif
    //freopen("T2.out", "w", stdout);
#define int long long
    rd(n), rd(P), rd(Q); int x, y;
    K = max(P, Q);
    rep(i, 1, n - 1){
        rd(x), rd(y);
        adde(x, y), adde(y, x);
    }
    dfs(1, -1);
    for(int i = 0; i <= K; ++i) F[1][i] = f[1][i];
    dfs2(1, -1);
    for(int i = 0; i <= K; ++i){
        for(int j = 1; j <= n; ++j){
            S[i] += F[j][i];
        }
    }
    dfs3(1, -1);
    int ans = 0;
    for(int i = 1; i <= n; ++i){
        ans += d[i][P] * G[i][Q];
        ans += d[i][Q] * G[i][P];
    }
    ans -= R[1] * 4;
    printf("%lld\n", ans);
    //printf("%.3lf M\n", (double)(&cur2 - &cur1) / (1 << 20));
    return 0;
}