[数学记录]P5206 [WC2019] 数树
command_block · · 个人记录
这题对我有一种特殊的意义。
前年国庆,第一次听说FFT这种科技,断断续续学了一个月才会,更没听说过幂级数科技,从此就搁置了。
前(去)年冬天,年少轻狂,以菜逼的水平,苟进了WC2019。讲课期间受同宿舍以及同市大佬的指点,开始学习多项式全家桶。
比赛的时候,这道题慷慨地给出了大量暴力分,于是我Ag了(其实主要还是靠T2的乱搞)。
断断续续花了一个星期学到了EXP,因为数学太差,切完模板之后又都搁置下来。
终于在PKUSC的前夕学会了一点生成函数,其实只是一点断章取义的皮毛罢了。但是自己当时很开心,重新打开这道题端详了一阵,结果发现连题解都看不懂,表示深深Orz.
PKUSC的赛场上,两天试机都写了个NTT,结果根本没用上。后来想起,不禁为自己感到庆幸。
然后就是最后一个无忧无虑的暑假,写了一堆斯特林数之类的,自以为会了。NOI2019同步赛测试输出没删,惨打铁。
然后就是大浪淘沙的初三季,从文化课的缝隙中挤出一点时间学信息学。终于狠狠心停课备CSP,结果打出一堆降智操作,成绩不如初二。
看完了《混凝土数学》,觉得自己生成函数水平大进。可是环顾四周,许多同届甚至低一届的大佬们早就在交谈我听不懂的幂级数科技了。
不知在多少个周末晚上,找不到题目做的时候,随手打开这道题看上一看,终究还是丢弃了不敢写。
题意 : 本题是三合一题目。
已知标号对应的两棵树,如果有某条边同时在两棵树中出现,则被联通的点成为等价类,等价关系会传递。
给出参数
形式化地讲,设两棵树的边集为
- 任务0
给出
- 任务1
给出
- 任务2
对于所有可能的
所有答案均对
对所有
-
任务0 (给出
S1,S2 )
题意理解分,直接std::map送走。
-
任务1 (给出
S1 )
方便起见,我们把贡献改为
我们把
这里涉及到了集合,那就要普及一些常用的分析手段。
可见 炫酷反演魔术
对于
我们令
注意到,
这个结果相当漂亮。
-
现在咱来捣鼓
g(S) 。先观察一下函数值和什么有关。
显然,
S 会将所有点连成若干个联通块,如果要补成一棵树,则要在块间连边,块内什么情况可以不加理会。假设
n 个点的森林,分成了m 个联通块,第i 个的大小为a_i 。如果所有的
a_i=1 ,答案显然是n^{n-2} ,现在某些a_i>1 ,则表示连这个"大点"有a_i 种方案。考虑一种可能的度数序列
d ,则对应连边方案是\prod\limits_{i=1}^ma_i^{d_i} .构造
prufer序列p ,设c_i 为i 号大点的出现次数,有d_i=c_i+1 。可得
\sum\limits_{p}\prod\limits_{i=1}^ma_i^{c_i+1} =\prod\limits_{i=1}^ma_i\sum\limits_{p}\prod\limits_{i=1}^ma_i^{c_i} =\prod\limits_{i=1}^ma_i\sum\limits_{p}\prod\limits_{i=1}^{m-2}a_{p_i} 注意到
\sum\limits_{p}\prod\limits_{i=1}^{m-2}a_{p_i}=\prod\limits_{i=1}^{m-2}\sum\limits_{p}a_{p_i}=\prod\limits_{i=1}^{m-2}n=n^{m-2} =n^{m-2}\prod\limits_{i=1}^ma_i
我们带回原式: 注意
前面的一坨按下不表,后面的意思是 : 令
然后我们就能来DP一下了。
设
转移是比较显然的树上背包:
则有
转移的时候,则有
-
H[u]=H^*[u]H[v]+H^*[u]F_v(1)+H[v]F_u^*(1)
所以,我们只需要维护
-
F_u(1)=F_u^*(1)*(H[v]+F_v(1))
复杂度DP有一个优美的组合意义。
-
任务2 (啥都不给)
把上一个任务的式子借过来用:
注意到
把
好了,现在的意思是,对于一个大小为
联通块并非铁板一块,还要乘以内部生成树的方案
弄出EGF然后EXP即可。直接拉了个板子,复杂度
不得不说这个三合一挺自然的,每个任务都有自己的妙处,给出题人点赞!
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#include<set>
#define pf printf
#define ll long long
#define mod 998244353
#define MaxN 100500
using namespace std;
inline int read(){
register int X=0;
register char ch=0;
while(ch<48||ch>57)ch=getchar();
while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
return X;
}
ll powM(ll a,int t=mod-2){
ll ret=1;
while(t){
if (t&1)ret=ret*a%mod;
a=a*a%mod;t>>=1;
}return ret;
}
int n;ll y;
namespace Solver0
{
#define Pr pair<int,int>
#define mp make_pair
set<Pr> s;
void solve()
{
if (y==1){pf("1");return ;}
for (int i=1,u,v;i<n;i++){
u=read();v=read();
if (u>v)swap(u,v);
s.insert(mp(u,v));
}int cnt=0;
for (int i=1,u,v;i<n;i++){
u=read();v=read();
if (u>v)swap(u,v);
cnt+=s.count(mp(u,v));
}pf("%lld",powM(y,n-cnt));
}
};
namespace Solver1
{
#define pb push_back
vector<int> g[MaxN];
ll H[MaxN],F[MaxN],C;
void dfs(int u)
{
H[u]=C;F[u]=1;
for (int i=0,v;i<g[u].size();i++)
if (!H[v=g[u][i]]){
dfs(v);
H[u]=(H[u]*(H[v]+F[v])+H[v]*F[u])%mod;
F[u]=F[u]*(H[v]+F[v])%mod;
}
}
void solve()
{
if (y==1){pf("%lld",powM(n,n-2));return ;}
ll buf=powM(y,n);
buf=buf*powM((y=powM(y))-1,n)%mod*powM(n,mod-3)%mod;
C=powM(y-1)*n%mod;
for (int i=1,u,v;i<n;i++){
u=read();v=read();
g[u].pb(v);g[v].pb(u);
}dfs(1);
pf("%lld",H[1]*buf%mod);
}
};
namespace Solver2
{
#define clr(f,n) memset(f,0,sizeof(ll)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(ll)*(n))
#define Maxn 135000
#define G 3
int tr[Maxn<<1],tf;
const ll invG=powM(G);
void tpre(int n){
if (tf==n)return;tf=n;
for(int i=0;i<n;i++)
tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
}
void NTT(ll *f,bool op,int n)
{
tpre(n);
for (int i=0;i<n;i++)
if (i<tr[i])swap(f[i],f[tr[i]]);
for(int p=2;p<=n;p<<=1){
int len=p/2;
ll tG=powM(op ? G:invG,(mod-1)/p);
for(int k=0;k<n;k+=p){
ll buf=1;
for(int l=k;l<k+len;l++){
int tt=buf*f[len+l]%mod;
f[len+l]=f[l]-tt;
if (f[len+l]<0)f[len+l]+=mod;
f[l]+=tt;
if (f[l]>=mod)f[l]-=mod;
buf=buf*tG%mod;
}
}
}if (!op){
ll invn=powM(n);
for(int i=0;i<n;++i)
f[i]=f[i]*invn%mod;
}
}
void px(ll *f,ll *g,int n)
{for(int i=0;i<n;++i)f[i]=f[i]*g[i]%mod;}
ll _g1[Maxn<<1];
#define sav _g1
void times(ll *f,ll *g,int len,int lim)
{
int n=1;for(n;n<len+len;n<<=1);
cpy(sav,g,n);
for(int i=len;i<n;i++)sav[i]=0;
NTT(f,1,n);NTT(sav,1,n);
px(f,sav,n);NTT(f,0,n);
for(int i=lim;i<n;++i)f[i]=0;
clr(sav,n);
}
ll _w2[Maxn<<1],_r2[Maxn<<1];
void inv(ll *f,int m)
{
int n;for (n=1;n<m;n<<=1);
#define w _w2
#define r _r2
w[0]=powM(f[0]);
for (int len=2;len<=n;len<<=1){
for (int i=0;i<(len>>1);i++)
r[i]=(w[i]<<1)%mod;
memcpy(sav,f,sizeof(ll)*len);
NTT(w,1,len<<1);px(w,w,len<<1);
NTT(sav,1,len<<1);px(w,sav,len<<1);
NTT(w,0,len<<1);clr(w+len,len);
for (int i=0;i<len;i++)
w[i]=(r[i]-w[i]+mod)%mod;
}cpy(f,w,m);clr(sav,n+n);clr(w,n+n);clr(r,n+n);
#undef w
#undef r
}
#undef sav
void dao(ll *f,int m){
for (int i=1;i<m;i++)
f[i-1]=f[i]*i%mod;
f[m-1]=0;
}
void jifen(ll *f,int m){
for (int i=m;i;i--)
f[i]=f[i-1]*powM(i)%mod;
f[0]=0;
}
ll _s3[Maxn<<1];
void lnp(ll *f,int m)
{
#define g _s3
cpy(g,f,m);
inv(g,m);dao(f,m);
times(f,g,m,m);jifen(f,m-1);
clr(g,m);
#undef g
}
ll _xp[Maxn<<1],_xp2[Maxn<<1];
void exp(ll *f,int m)
{
#define s _xp
#define s2 _xp2
int n=1;for(;n<m;n<<=1);
s2[0]=1;
for (int len=2;len<=n;len<<=1){
cpy(s,s2,len>>1);lnp(s,len);
for (int i=0;i<len;i++)
s[i]=(f[i]-s[i]+mod)%mod;
s[0]=(s[0]+1)%mod;
times(s2,s,len,len);
}cpy(f,s2,m);clr(s,n+n);clr(s2,n+n);
#undef s
#undef s2
}
ll fac[MaxN],ifac[MaxN];
void Init()
{
fac[0]=1;
for (int i=1;i<=n;i++)
fac[i]=fac[i-1]*i%mod;
ifac[n]=powM(fac[n]);
for (int i=n;i;i--)
ifac[i-1]=ifac[i]*i%mod;
}
ll F[Maxn<<1];
void solve()
{
if (y==1){pf("%lld",powM(n,2*(n-2)));return ;}
ll buf=powM(n,mod-5)*powM(y,n)%mod,
xb=1ll*n*n%mod*powM((y=powM(y))-1)%mod;
buf=buf*powM(y-1,n)%mod;
Init();
for (int i=1;i<=n;i++)
F[i]=xb*powM(i,i)%mod*ifac[i]%mod;
exp(F,n+1);
pf("%lld",F[n]*fac[n]%mod*buf%mod);
}
};
int main()
{
n=read();y=read();
int op=read();
if (op==0)Solver0::solve();
if (op==1)Solver1::solve();
if (op==2)Solver2::solve();
return 0;
}