[题解]P8935 [JRKSJ R7] 茎

· · 题解

更好的阅读体验

思路

首先思考 x = 1 的做法。有一个很简单的 dp,定义 dp_{u,i} 表示在 u 子树中操作了 i 次的方案数,需要注意的是这里并不强制要把 u 子树全部切掉。

考虑转移。因为 u 被切掉过后其任何子树内都将无法进行操作,所以如果要操作 u 则这一步必定是最后一步。于是先合并子树信息,有:

dp'_{u,i + j} \leftarrow \binom{i + j}{j}dp_{u,i} \times dp_{v,j}

然后再考虑 u 是否被操作,有:

dp'_{u,i} = dp_{u,i - 1} + dp_{u,i}

接下来考虑 x \neq 1 的做法。注意到从 1 \leadsto x 路径上的任意选一个点操作都会让 x 被切掉,不妨把这条链单独拎出来进行 dp。定义 f_{u,i} 表示在 1 \leadsto x 这条路径上,只选择在 1 \leadsto u 上的点上进行操作,并将整棵树切完,同时切下 u 之前操作了 i 步的方案数。考虑 (u,v) 边的转移:

因为第 k 次操作必须操作 u,所以在转移的时候不能转移不选择 u 的情况。转移用后缀和优化可以做到 \Theta(n^2)

Code

#include <bits/stdc++.h>
#define re register
#define int long long
#define Add(a,b) (((a) + (b)) % mod)
#define Mul(a,b) ((a) * (b) % mod)
#define chAdd(a,b) (a = Add(a,b))
#define chMul(a,b) (a = Mul(a,b))

using namespace std;

const int N = 510;
const int mod = 1e9 + 7;
int n,k,p;
int fp[N],sz[N],son[N],tmp[N];
int fac[N],infac[N],dp[N][N],f[N][N];
vector<int> g[N];

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline int qmi(int a,int b){
    int res = 1;
    while (b){
        if (b & 1) chMul(res,a);
        chMul(a,a); b >>= 1;
    } return res;
}

inline void init(int n){
    fac[0] = 1;
    for (re int i = 1;i <= n;i++) fac[i] = Mul(fac[i - 1],i);
    infac[n] = qmi(fac[n],mod - 2);
    for (re int i = n - 1;~i;i--) infac[i] = Mul(infac[i + 1],i + 1);
}

inline int C(int n,int m){
    if (n < m) return 0;
    else return Mul(fac[n],Mul(infac[m],infac[n - m]));
}

inline void dfs(int u,int fa){
    fp[u] = fa,dp[u][0] = 1;
    if (u == p) son[fa] = u;
    for (int v:g[u]){
        if (v == fa) continue;
        dfs(v,u);
        if (son[v]) son[u] = v;
        for (re int i = 0;i <= sz[u] + sz[v];i++) tmp[i] = 0;
        for (re int i = 0;i <= sz[u];i++){
            for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[u][i],dp[v][j])));
        } sz[u] += sz[v];
        for (re int i = 0;i <= sz[u];i++) dp[u][i] = tmp[i];
    } sz[u]++;
    for (re int i = sz[u];i;i--) chAdd(dp[u][i],dp[u][i - 1]);
}

signed main(){
    n = read(),k = read(),p = read();
    init(n);
    for (re int i = 1,a,b;i < n;i++){
        a = read(),b = read();
        g[a].push_back(b);
        g[b].push_back(a);
    } dfs(1,0);
    int u = 1,num = 0;
    f[1][0] = 1;
    while (u){
        num++;
        for (re int i = 0;i <= num + 2;i++) tmp[i] = 0;
        for (re int i = num;~i;i--) tmp[i] = Add(tmp[i + 1],f[fp[u]][i]);
        if (u != 1){
            for (re int i = 0;i <= num;i++){
                f[u][i] = tmp[i];
                if (u != p) chAdd(f[u][i],f[fp[u]][i]);
            }
        }
        for (int v:g[u]){
            if (v == fp[u] || v == son[u]) continue;
            for (re int i = 0;i <= num + sz[v];i++) tmp[i] = 0;
            for (re int i = 0;i <= num;i++){
                for (re int j = 0;j <= sz[v];j++) chAdd(tmp[i + j],Mul(C(i + j,j),Mul(dp[v][j],f[u][i])));
            } num += sz[v];
            for (re int i = 0;i <= num;i++) f[u][i] = tmp[i];
        } u = son[u];
    } printf("%lld",f[p][k - 1]);
    return 0;
}