题解:P15649 [省选联考 2026] 找寻者 / recollector
lailai0916 · · 题解
题意简述
以结点
解题思路
先把贡献拆到每条边上。结点
接着求每条边的选择概率。设
还需维护每个结点的生成函数。设选中儿子的重链长度为
而
实现上自底向上 DFS。叶子
一个常数优化:纯链(含叶子)的
时间复杂度为
参考代码
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
const int mod=998244353;
const int N=5005;
int siz[N];
ll inv[N],ans;
vector<int> G[N];
ll Pow(ll x,ll y)
{
x%=mod;
ll res=1;
while(y)
{
if(y&1)res=res*x%mod;
x=x*x%mod;
y>>=1;
}
return res;
}
void init()
{
for(int i=1;i<N;i++)inv[i]=i==1?1:(mod-mod/i)*inv[mod%i]%mod;
}
vector<ll> dfs(int u,int fa)
{
siz[u]=1;
vector<vector<ll>> fc;
vector<int> cv,mL,ms;
for(int v:G[u])
{
if(v==fa)continue;
vector<ll> fv=dfs(v,u);
siz[u]+=siz[v];
int nz=0,deg=0;
for(int i=0;i<fv.size();i++)if(fv[i]){nz++;deg=i;}
if(nz==1){mL.push_back(deg);ms.push_back(siz[v]);}
else{fc.push_back(move(fv));cv.push_back(v);}
}
if(fc.empty()&&mL.empty())return {0,1};
vector<ll> p={1};
for(auto &f:fc)
{
vector<ll> g(p.size()+f.size()-1,0);
for(int j=0;j<f.size();j++)
{
if(!f[j])continue;
ll fj=f[j];
for(int i=0;i<p.size();i++)if(p[i])g[i+j]=(g[i+j]+p[i]*fj)%mod;
}
p=move(g);
}
int dd=p.size()-1,s=0,mxd=0;
for(int l:mL)
{
s+=l;
mxd=max(mxd,l);
}
for(auto &f:fc)
{
int d=f.size()-1;
mxd=max(mxd,d);
}
vector<ll> fu(mxd+2,0);
if(!mL.empty())
{
ll im=0;
for(int m=0;m<=dd;m++)im=(im+p[m]*inv[s+m])%mod;
for(int i=0;i<mL.size();i++)
{
int l=mL[i];
ll pk=(ll)l*im%mod;
fu[l+1]=(fu[l+1]+pk)%mod;
ans=(ans+(ll)ms[i]*((1-pk+mod)%mod))%mod;
}
}
for(int c=0;c<fc.size();c++)
{
auto &f=fc[c];
int dc=f.size()-1,dr=dd-dc;
vector<int> nz;
for(int j=0;j<=dc;j++)if(f[j])nz.push_back(j);
vector<ll> w(p),r(dr+1,0);
ll il=Pow(f[dc],mod-2);
for(int i=dd;i>=dc;i--)
{
ll q=w[i]*il%mod;
r[i-dc]=q;
for(int j:nz)w[i-dc+j]=((w[i-dc+j]-q*f[j])%mod+mod)%mod;
}
ll pp=0;
for(int l:nz)
{
if(!l)continue;
ll t=0;
for(int m=0;m<=dr;m++)t=(t+r[m]*inv[l+s+m])%mod;
t=t*l%mod*f[l]%mod;
fu[l+1]=(fu[l+1]+t)%mod;
pp=(pp+t)%mod;
}
ans=(ans+(ll)siz[cv[c]]*((1-pp+mod)%mod))%mod;
}
return fu;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
init();
int cas,t;
cin>>cas>>t;
while(t--)
{
int n;
cin>>n;
for(int i=1;i<=n;i++)G[i].clear();
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
ans=0;
dfs(1,0);
cout<<ans<<'\n';
}
return 0;
}