P13020 [GESP202506 八级] 遍历计数 题解

· · 题解

虽然本人思路看上去较冗长,但展现的是本人在考场上的一步步简化问题的步骤,希望能给读者带来帮助。

首先,考虑以1为起点遍历的情况。考虑树形dp。

f_uu 为顶点的子树的不同方案数
u 的儿子为 v_1,v_2,...,v_k ,则有

f_u=A_k^k \times \prod_{i = 1}^{k} \ f_{v_i}

其中 A_k^k 表示 k 的全排列数量

细节看代码(A 数组预处理即可):

#define int long long
m=1e9
void dfs(int pos,int fa){
    f[pos]=1;
    for(auto son:e[pos])
        if(son!=fa){
            dfs(son,pos);
            cnt[pos]++;
            f[pos]=f[pos]*f[son]%m;
        }
    f[pos]=f[pos]*A[cnt[pos]]%m;
}

上面的 dfs 时间复杂度为 O(n)n 个顶点肯定过不了100%

注意到:

当点 u,v 之间有边相连时,分别以 u,v 为根遍历,得到的 f 数组只有 f_u,f_v 发生改变。希望找到二者的关系
考虑任意一边 (u,v),设两点度数为 d_u,d_v

u 为节点时

f_u=A_{d_u}^{d_u} \times (\prod_{i = 1}^{d_u-1} \ f_{u_i}) \times f_v=A_{d_u}^{d_u} \times (\prod_{i = 1}^{d_u-1} \ f_{u_i}) \times A_{d_v-1}^{d_v-1} \times (\prod_{i = 1}^{d_v-1} \ f_{v_i})

v 为节点时

f_v'=A_{d_v}^{d_v} \times (\prod_{i = 1}^{d_v-1} \ f_{v_i}) \times f_u'=A_{d_v}^{d_v} \times (\prod_{i = 1}^{d_v-1} \ f_{v_i}) \times A_{d_u-1}^{d_u-1} \times (\prod_{i = 1}^{d_u-1} \ f_{u_i})

二者相比得:

\frac{ans_u}{ans_v}=\frac{f_u}{f_v'}=\frac{A_{d_v-1}^{d_v-1} \times A_{d_u}^{d_u}}{A_{d_v}^{d_v} \times A_{d_u-1}^{d_u-1}}=\frac{d_u}{d_v}

所以,可以用深搜再跑一遍,计算出每个ans

但是 模运算不支持除法

由上文知

\frac{ans_u}{ans_v}=\frac{d_u}{d_v}

ans_v=\frac{f_u}{d_u} \times d_v

\frac{f_u}{d_u}v 无关 令 λ=\frac{f_u}{d_u}ans_v=λd_v

所求为

sum=\sum_{i = 1}^{n} \ ans_i=λ\sum_{i = 1}^{n} \ d_i

注意到 \sum_{i = 1}^{n} \ d_i=2(n-1) 故所求为

2λ(n-1)

树形dp跑出 λ=\frac{f_u}{d_u} 即可

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
vector<int>e[N];
int n,u,v,A[N],f[N],cnt[N],m=1e9;
void dfs(int pos,int fa){
    f[pos]=1;
    for(auto son:e[pos])
        if(son!=fa){
            dfs(son,pos);
            cnt[pos]++;
            f[pos]=f[pos]*f[son]%m;
        }
    if(pos!=1) f[pos]=f[pos]*A[cnt[pos]]%m;//f[1]即为上文的λ
    else f[pos]=f[pos]*A[cnt[pos]-1]%m;
}
signed main(){
    cin>>n;
    if(n==1){cout<<1;return 0;} //n=1时没有边,一切免谈
    A[0]=1;
    for(int i=1;i<=1e5;i++) A[i]=A[i-1]*i%m;
    for(int i=1;i<n;i++){
        cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1,0);
    cout<<f[1]*(n-1)*2%m;
    return 0;
}