题解:P16355 「Diligent-OI R3 C」彼方へ、名もなき海辺より

· · 题解

赛时一发过了说是。

题面描述的很清楚了。在此不过多描述。

每个节点恰好选一条边,故总边数为 n,点数为 n。这样的图由若干个连通分量组成,每个分量恰好有一个环(因为 n 个点 n 条边,每个连通分量的边数等于点数时有且只有一个环)。

因此,连通块个数就是环的个数。

我们可以看全集为所有可能的选择方案,方案总数:

M = \prod_{u=1}^{n} \deg(u),\quad \deg(u) = \text{son}(u) + \text{depth}(u)

其中 \text{depth}(u) 表示节点 u 的深度(根深度为 0)。

对于每个节点 u,我们想让它成为所在环中最小节点的方案数。累加这些就是答案。

考虑以 1 为根的有根树。对于节点 u,它在环中的情况依赖于它的后代和祖先之间的边选择。通过分析,可以推导出递推关系。

定义 S[u] 为以 u 为根的子树中,节点 u 成为环上最小节点的方案数与总方案数的比值。利用树上的独立性和概率转移,可以得到递推式:

S[u] = \frac{1}{\deg(u)} \sum_{v\in \text{son}(u)} \left( \frac{1}{\deg(v)} + S[v] \right)

其中 1/\deg(v) 的项来自节点 v 直接连向父亲 u 的情况。

按照后序遍历(从叶子向上)计算 S[u],最终期望连通块数:

E = \sum_{u=1}^{n} S[u]

综上可得知总方案数为 M \times E \pmod{998244353}

时间复杂度 O(n\log MOD)(快速幂求逆元)或 O(n)(线性求逆元),空间 O(n),可处理 n\le 5\times10^5


#include<bits/stdc++.h>
using namespace std;
const int MOD=998244353;
const int N=500005;
int n,deg[N],S[N];
vector<int> g[N];
int dep[N],pa[N],order[N],cnt;
int qpow(int a,int b){
    int r=1;
    for(;b;b>>=1,a=1LL*a*a%MOD) if(b&1) r=1LL*r*a%MOD;
    return r;
}
void dfs(int u,int f){
    pa[u]=f; dep[u]=dep[f]+1;
    for(int v:g[u]) if(v!=f){
        deg[u]++;   // 儿子数
        dfs(v,u);
    }
}
void bfs(){
    queue<int> q; q.push(1); cnt=0;
    while(!q.empty()){
        int u=q.front(); q.pop();
        order[++cnt]=u;
        for(int v:g[u]) if(v!=pa[u]) q.push(v);
    }
}
int main(){
    ios::sync_with_stdio(0); cin.tie(0);
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v; cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dep[0]=-1; dfs(1,0);
    // 度数 = 儿子数 + 深度 (根深度0)
    for(int i=1;i<=n;i++) deg[i] += dep[i];
    int M=1;
    for(int i=1;i<=n;i++) M=1LL*M*deg[i]%MOD;
    // 预处理逆元
    vector<int> inv(n+1);
    for(int i=1;i<=n;i++) inv[i]=qpow(deg[i],MOD-2);
    bfs();
    // 逆序处理即从叶子向上
    for(int i=n;i>=1;i--){
        int u=order[i];
        int sum=0;
        for(int v:g[u]) if(v!=pa[u]){
            sum = (sum + inv[v] + S[v]) % MOD;
        }
        S[u] = 1LL * sum * inv[u] % MOD;
    }
    int E=0;
    for(int i=1;i<=n;i++) E=(E+S[i])%MOD;
    int ans=1LL*M*E%MOD;
    cout<<ans<<'\n';
    return 0;
}