题解:P15649 [省选联考 2026] 找寻者 / recollector

· · 题解

题意简述

以结点 1 为根。自底向上地为每个非叶结点 u 选重儿子:在儿子们各自的重链长度 l_1,\dots,l_k 确定后,u 以正比于 l_i 的概率选 v_i,于是 u 的重链长度变为 1+l_i。求每个结点到根路径上轻边数量的期望之和,对 998244353 取模。

解题思路

先把贡献拆到每条边上。结点 x 到根的路径经过边 (u,c)c 为儿子)当且仅当 x 在子树 c 内,这样的 x\operatorname{siz}(c) 个。该边是轻边当且仅当 u 没选 c。记 P(u\to c)u 选中儿子 c 的概率,由期望的线性性,

\text{ans}=\sum_{(u,c)}\operatorname{siz}(c)\bigl(1-P(u\to c)\bigr).

接着求每条边的选择概率。设 u 的各儿子重链长度为随机变量 L_1,\dots,L_k(不同子树相互独立),S=\sum L_j,则 P(u\to c)=\mathbb E\left[\frac{L_c}{S}\right]。分母里的求和是难点,用积分恒等式 \frac1S=\int_0^1 t^{S-1}\mathrm dt 把它拆开:记 f_j(t)=\mathbb E[t^{L_j}]L_j 的概率生成函数,由独立性

P(u\to c)=\int_0^1 f_c'(t)\prod_{j\ne c}f_j(t)\mathrm dt.

还需维护每个结点的生成函数。设选中儿子的重链长度为 L^*,则 L_u=1+L^*,故 f_u(x)=x\mathbb E[x^{L^*}]。同样用上面的积分展开,记 p_{c,\ell}=[t^\ell]f_c(t)R_c(t)=\prod_{j\ne c}f_j(t),可得

[x^{\ell+1}]f_u=\sum_{c}\ell p_{c,\ell}\int_0^1 t^{\ell-1}R_c(t)\mathrm dt,

P(u\to c) 恰是上式中第 c 个儿子对所有 \ell 的贡献之和。其中 \int_0^1 t^m\mathrm dt=\frac1{m+1},在模意义下即 m+1 的逆元。

实现上自底向上 DFS。叶子 f_u(x)=x;非叶结点先把所有儿子的 f_j 乘成 \prod f_j,再对每个 c 用多项式除法得到 R_c=\prod f_j / f_c,按上式累加出 f_u 与各 P(u\to c)。每个结点的多项式次数等于子树高度,由树形背包的配对计数,总复杂度 O(n^2)

一个常数优化:纯链(含叶子)的 f 是单项式 x^L,乘除都退化成移位。把这类儿子单独拎出来——它们对同一父结点贡献的那个积分值完全相同(链长在分子分母里恰好抵消),一次算出再乘以各自的 L 即可,无需逐个做稠密多项式运算。这样星形、菊花、链等结构从 O(n^2) 降到近 O(n)

时间复杂度为 O(n^2)

参考代码

#include <bits/stdc++.h>
using namespace std;

using ll=long long;
const int mod=998244353;
const int N=5005;
int siz[N];
ll inv[N],ans;
vector<int> G[N];
ll Pow(ll x,ll y)
{
    x%=mod;
    ll res=1;
    while(y)
    {
        if(y&1)res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void init()
{
    for(int i=1;i<N;i++)inv[i]=i==1?1:(mod-mod/i)*inv[mod%i]%mod;
}
vector<ll> dfs(int u,int fa)
{
    siz[u]=1;
    vector<vector<ll>> fc;
    vector<int> cv,mL,ms;
    for(int v:G[u])
    {
        if(v==fa)continue;
        vector<ll> fv=dfs(v,u);
        siz[u]+=siz[v];
        int nz=0,deg=0;
        for(int i=0;i<fv.size();i++)if(fv[i]){nz++;deg=i;}
        if(nz==1){mL.push_back(deg);ms.push_back(siz[v]);}
        else{fc.push_back(move(fv));cv.push_back(v);}
    }
    if(fc.empty()&&mL.empty())return {0,1};
    vector<ll> p={1};
    for(auto &f:fc)
    {
        vector<ll> g(p.size()+f.size()-1,0);
        for(int j=0;j<f.size();j++)
        {
            if(!f[j])continue;
            ll fj=f[j];
            for(int i=0;i<p.size();i++)if(p[i])g[i+j]=(g[i+j]+p[i]*fj)%mod;
        }
        p=move(g);
    }
    int dd=p.size()-1,s=0,mxd=0;
    for(int l:mL)
    {
        s+=l;
        mxd=max(mxd,l);
    }
    for(auto &f:fc)
    {
        int d=f.size()-1;
        mxd=max(mxd,d);
    }
    vector<ll> fu(mxd+2,0);
    if(!mL.empty())
    {
        ll im=0;
        for(int m=0;m<=dd;m++)im=(im+p[m]*inv[s+m])%mod;
        for(int i=0;i<mL.size();i++)
        {
            int l=mL[i];
            ll pk=(ll)l*im%mod;
            fu[l+1]=(fu[l+1]+pk)%mod;
            ans=(ans+(ll)ms[i]*((1-pk+mod)%mod))%mod;
        }
    }
    for(int c=0;c<fc.size();c++)
    {
        auto &f=fc[c];
        int dc=f.size()-1,dr=dd-dc;
        vector<int> nz;
        for(int j=0;j<=dc;j++)if(f[j])nz.push_back(j);
        vector<ll> w(p),r(dr+1,0);
        ll il=Pow(f[dc],mod-2);
        for(int i=dd;i>=dc;i--)
        {
            ll q=w[i]*il%mod;
            r[i-dc]=q;
            for(int j:nz)w[i-dc+j]=((w[i-dc+j]-q*f[j])%mod+mod)%mod;
        }
        ll pp=0;
        for(int l:nz)
        {
            if(!l)continue;
            ll t=0;
            for(int m=0;m<=dr;m++)t=(t+r[m]*inv[l+s+m])%mod;
            t=t*l%mod*f[l]%mod;
            fu[l+1]=(fu[l+1]+t)%mod;
            pp=(pp+t)%mod;
        }
        ans=(ans+(ll)siz[cv[c]]*((1-pp+mod)%mod))%mod;
    }
    return fu;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    int cas,t;
    cin>>cas>>t;
    while(t--)
    {
        int n;
        cin>>n;
        for(int i=1;i<=n;i++)G[i].clear();
        for(int i=1;i<n;i++)
        {
            int u,v;
            cin>>u>>v;
            G[u].push_back(v);
            G[v].push_back(u);
        }
        ans=0;
        dfs(1,0);
        cout<<ans<<'\n';
    }
    return 0;
}