题解:P15649 [省选联考 2026] 找寻者 / recollector

· · 题解

好难啊,LA 中已经关于前后缀背包和退背包讨论很久了,但是我觉得可以利用积分来做吧。

(暂时不确定正确性,qoj 数据过了)

更新日志:

2026/3/7 23:03 修改了由 wjwWeiwei 大佬提出的关于时间复杂度的问题,并重写了代码。

L_u 为以 u 为顶端的重链长度。这是一个随机变量。由于子树的选择是递归且独立的,我们可以通过底向上的 DP 来维护 L_u 的概率分布。

G_u(x) = \sum P(L_u = i) x^iL_u 的生成函数。 对于非叶结点 u,设其儿子为 v_1, v_2, \dots, v_k。根据题目,选择 v_j 为重儿子的概率为 \mathbb{E}\left[\frac{L_{v_j}}{\sum_{m=1}^k L_{v_m}}\right]

W_vv 成为其父亲 u 的重儿子的概率。根据期望的线性性质,所有结点到根的轻边数量期望之和,等价于每条边 (v, parent(v)) 贡献的期望之和。

一条边 (v, u) 对答案的贡献次数等于 v 的子树大小 sz_v(因为它是 sz_v 个结点到根路径上的共同边),而它成为轻边的概率是 1 - W_v。因此,答案为:\sum_{v \neq 1} sz_v \cdot (1 - W_v)

利用积分恒等式 \frac{1}{S} = \int_0^1 t^{S-1} dt,可以推导出:

W_{v_j} = \sum_a P(L_{v_j} = a) \cdot a \cdot \mathbb{E}\left[\frac{1}{a + \sum_{m \neq j} L_{v_m}}\right]

其中 \mathbb{E}\left[\frac{1}{a + S_{rem}}\right] = \int_0^1 t^{a-1} \left( \prod_{m \neq j} G_{v_m}(t) \right) dt

由于 G_{v_m}(t) 是多项式,乘积也是多项式,积分可以通过 \sum \frac{coeff_d}{a+d} 直接 O(N) 计算。

但是不难发现,这样在遇到某些度数极大的情况时,如果暴力计算前后缀的乘积,会使时间复杂度退化至 O(N^3) 的存在,这显然是不可以的,我们考虑优化。

不难发现,可以直接使用多项式除法来解决。

考虑到子节点的重链长度至少为 1,生成函数必然没有常数项,所以说我们可以对多项式进行平移,提取出它的最低非零项。

对于每个子节点 v,其多项式 G_v(x) 最低次项必然 \ge 1。我们可以找到最低非零项的次数 m_v,令 G_v(x) = x^{m_v} Q_v(x),这样 Q_v(x) 的常数项 Q_v(0) 就必定不为 0 了。

将所有平移后的多项式 Q_v(x) 像树上背包一样乘起来得到 Total_Q。利用多项式乘法的性质,这里的合并总复杂度是严格的 O(\sum sz_v^2) = O(sz_u^2)

对于特定儿子 v,除它以外其余兄弟的生成函数乘积就是 \frac{TotalQ}{Q_v}。因为 Q_v(0) \neq 0,可以直接通过递推在 O(|TotalQ| \cdot |Q_v|) = O(sz_u \cdot sz_v) 的时间内算出商多项式 C_v

得到的商多项式对应的真实次数需要补回 \sum_{w \ne v} m_w。也就是积分里体现为分母项加上对应的偏置 shift 即可。

#include<bits/stdc++.h>
using namespace std;
const int Mod=998244353,INF=5005;
namespace Math {
    long long inv[INF];
    inline void prep(){
        inv[1]=1;
        for(int i=2;i<INF;i++)inv[i]=Mod-(Mod/i)*inv[Mod%i]%Mod;
    }
    inline long long qpow(long long a,long long b){
        long long r=1;a%=Mod;
        while(b){
            if(b&1)r=r*a%Mod;
            a=a*a%Mod;
            b>>=1;
        }
        return r;
    }
}
namespace Solver {
    using namespace Math;
    vector<int> g[INF];
    vector<long long> f[INF];
    long long W[INF];
    int sz[INF];
    void dp(int u,int p){
        sz[u]=1;
        vector<int> ch;
        for(int v:g[u])if(v!=p){dp(v,u);sz[u]+=sz[v];ch.push_back(v);}
        if(ch.empty())return f[u]={0,1},void();
        int k=ch.size(),sm=0;
        vector<vector<long long>> Q(k);
        vector<int> mv(k);
        vector<long long> TQ={1};
        for(int i=0;i<k;i++){
            int v=ch[i],m=1;
            while(m<(int)f[v].size()&&!f[v][m])m++;
            mv[i]=m;sm+=m;
            Q[i].assign(f[v].size()-m,0);
            for(int j=m;j<(int)f[v].size();j++)Q[i][j-m]=f[v][j];
            vector<long long> nxt(TQ.size()+Q[i].size()-1,0);
            for(int a=0;a<(int)TQ.size();a++)if(TQ[a])for(int b=0;b<(int)Q[i].size();b++)nxt[a+b]=(nxt[a+b]+TQ[a]*Q[i][b])%Mod;
            TQ=nxt;
        }
        f[u].assign(sz[u]+1,0);
        for(int i=0;i<k;i++){
            int v=ch[i],sh=sm-mv[i],cs=TQ.size()-Q[i].size()+1;
            vector<long long> C(cs,0);
            long long iv=qpow(Q[i][0],Mod-2);
            for(int j=0;j<cs;j++){
                long long s=TQ[j];
                for(int x=1;x<=j&&x<(int)Q[i].size();x++)s=(s-Q[i][x]*C[j-x]%Mod+Mod)%Mod;
                C[j]=s*iv%Mod;
            }
            for(int a=mv[i];a<(int)f[v].size();a++){
                if(f[v][a]){
                    long long itg=0;
                    for(int d=0;d<(int)C.size();d++)if(C[d])itg=(itg+C[d]*inv[a+sh+d])%Mod;
                    long long pb=f[v][a]*a%Mod*itg%Mod;
                    W[v]=(W[v]+pb)%Mod;
                    f[u][a+1]=(f[u][a+1]+pb)%Mod;
                }   
            }
        }
    }
    inline void clr(int n){
        for(int i=1;i<=n;i++){
            g[i].clear(),f[i].clear();
            W[i]=sz[i]=0;
        }
    }
}
namespace yixing {
    inline void sol(){
        int n;cin>>n;
        Solver::clr(n);
        for(int i=1,u,v;i<n;i++){
            cin>>u>>v;
            Solver::g[u].push_back(v),Solver::g[v].push_back(u);
        }
        if (n==1)return cout<<"0\n",void();
        Solver::dp(1,0);
        long long ans=0;
        for(int i=2;i<=n;i++)ans=(ans+(Mod+1-Solver::W[i])*Solver::sz[i])%Mod;
        cout<<ans<<"\n";
    }
    inline void Main(){
        Math::prep();
        int c,t;if(!(cin>>c>>t))return;
        while(t--)sol();
    }
}
int main(){
    ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);
    yixing::Main();
    return 0;
}