[数学记录]P6516 [QkOI#R1] Quark and Graph
command_block
2020-08-07 13:24:08
**题意** : 图 $G$ 有 $n$ 个点 $m$ 条边,设 $dis[u]$ 为从 $1$ 号点出发,到达该点的最短距离。
现在给出 $dis[1...n]$ 求符合条件的图的个数,保证至少为 $1$。
$n\leq 10^5,m\leq 2\times10^5$ ,设 $t_i=\sum\limits_{u}[dis[u]=i]$ ,则有 $\sum\limits_{i}t_it_{i-1}\leq 2\times 10^5$
时限 $\texttt{3s}$。
------------
第一眼 : 这也能数?
看到数据范围中的奇怪约束 : ? ? ! !
其实题目的创意还是不错的,但是这个约束明示按照 $dis$ 分层……
首先建立最短路 DAG ,那么两端在不同层的边是 DAG 边,而连接同一层的边不是 DAG 边。
而题目中的限制显然是 DAG 边数的上界,看来复杂度和这个有关。
注意到 DAG 边和非 DAG 边的选取是独立的,我们分别考虑。
设 $G[k]$ 为选取了 $k$ 条 DAG 边的方案数,令 $G(x)$ 为其生成函数。
类似地有 $F(x)$ 为非 DAG 边的生成函数。
- DAG 边
容易发现,每层是独立的。设 $G_r(x)$ 为第 $r$ 层到下一层布置 DAG 边的生成函数。
下一层的每个点都必须至少有一条边相连,否则最短路不可能恰好为上一层加一。这些边的出发点是没有限制的。
考虑单个右侧点,可以选择连向左侧点的一个集合,但是这个集合不能为空,其生成函数为 $(x+1)^{t_r}-1$。
则 $G_r(x)=\Big((x+1)^{t_r}-1\Big)^{\small t_{r+1}}$
$G(x)$ 即为各个 $G_r(x)$ 的乘积,次数由奇怪的约束保证,不会高于 $\rm20w$。
直接分治 FFT 即可。
- 非 DAG 边
对于第 $r$ 层,可能的内部边有 $\binom{t_r}{2}$ 条。
内部边总数为 $T=\sum\limits_r \binom{t_r}{2}$。
则 $F(x)=\sum\limits_{i=0}\dbinom{T}{i}x^i$
然而 $T$ 可能大于模数 $p$ ,但是不会大于 $p^2$ ,而 $i<p$ 是一定成立的。
使用卢卡斯定理,得 $\dbinom{T}{i}=\dbinom{T\bmod p}{i}\dbinom{\lfloor T/p\rfloor}{0}$
显然 $\dbinom{\lfloor T/p\rfloor}{0}=1$ ,则有 $\dbinom{T}{i}=\dbinom{T\bmod p}{i}$ ,我们简单地令 $T$ 对 $p$ 取模即可。
取模之后, $T$ 仍可能很大,需要递推求组合数。
```cpp
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#define ll long long
#define ull unsigned long long
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
const int _G=3,mod=998244353,Maxn=1<<18|500;
using namespace std;
ll powM(ll a,ll t=mod-2){
ll ans=1;
while(t){
if(t&1)ans=ans*a%mod;
a=a*a%mod;t>>=1;
}return ans;
}
const int invG=powM(_G);
int tr[Maxn<<1],tf;
void tpre(int n){
if (tf==n)return ;tf=n;
for(int i=0;i<n;i++)
tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
}
void NTT(int *g,bool op,int n)
{
tpre(n);
static ull f[Maxn<<1],w[Maxn<<1];w[0]=1;
for (int i=0;i<n;i++)f[i]=(((ll)mod<<5)+g[tr[i]])%mod;
for(int l=1;l<n;l<<=1){
ull tG=powM(op?_G:invG,(mod-1)/(l+l));
for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod;
for(int k=0;k<n;k+=l+l)
for(int p=0;p<l;p++){
int tt=w[p]*f[k|l|p]%mod;
f[k|l|p]=f[k|p]+mod-tt;
f[k|p]+=tt;
}
}if (!op){
ull invn=powM(n);
for(int i=0;i<n;++i)
g[i]=f[i]%mod*invn%mod;
}else for(int i=0;i<n;++i)g[i]=f[i]%mod;
}
void px(int *f,int *g,int n)
{for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;}
#define Poly vector<int>
Poly operator + (const Poly &A,const Poly &B)
{
Poly C=A;C.resize(max(A.size(),B.size()));
for (int i=0;i<B.size();i++)C[i]=(C[i]+B[i])%mod;
return C;
}
Poly operator - (const Poly &A,const Poly &B)
{
Poly C=A;C.resize(max(A.size(),B.size()));
for (int i=0;i<B.size();i++)C[i]=(C[i]+mod-B[i])%mod;
return C;
}
Poly operator * (const int c,const Poly &A)
{
Poly C;C.resize(A.size());
for (int i=0;i<A.size();i++)C[i]=1ll*c*A[i]%mod;
return C;
}
int lim;
Poly operator * (const Poly &A,const Poly &B)
{
static int a[Maxn<<1],b[Maxn<<1];
for (int i=0;i<A.size();i++)a[i]=A[i];
for (int i=0;i<A.size();i++)a[i]=A[i];
cpy(a,&A[0],A.size());
cpy(b,&B[0],B.size());
Poly C;C.resize(min(lim,(int)(A.size()+B.size()-1)));
int n=1;for(n;n<A.size()+B.size()-1;n<<=1);
NTT(a,1,n);NTT(b,1,n);
px(a,b,n);NTT(a,0,n);
cpy(&C[0],a,C.size());
clr(a,n);clr(b,n);
return C;
}
Poly powP(Poly A,int k)
{
int n=A.size();
Poly ret;ret.resize(n);ret[0]=1;
while(k){
if (k&1){ret=ret*A;ret.resize(n);}
A=A*A;A.resize(n);k>>=1;
}return ret;
}
int fac[Maxn],ifac[Maxn];
int C(int n,int m)
{return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
void Init(int n)
{
fac[0]=1;
for (int i=1;i<=n;i++)
fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=powM(fac[n]);
for (int i=n;i;i--)
ifac[i-1]=1ll*ifac[i]*i%mod;
}
Poly getF(int t1,int t2)
{
Poly F;F.resize(t1+1);
for (int i=1;i<=t1;i++)F[i]=C(t1,i);
F.resize((F.size()-1)*t2+1);
return powP(F,t2);
}
int c[Maxn];
Poly solve(int l,int r)
{
if (l==r)return getF(c[l],c[l+1]);
int mid=(l+r)>>1;
return solve(l,mid)*solve(mid+1,r);
}
Poly F,G;
int n,m;
int main()
{
scanf("%d%d",&n,&m);lim=m+1;
for (int i=1,dis;i<=n;i++){
scanf("%d",&dis);
c[dis]++;
}Init(m+1);
F=solve(0,n-2);
int cnt=0;
for (int i=0;i<n;i++)
cnt=(cnt+1ll*c[i]*(c[i]-1)/2)%mod;
G.resize(m+1);
G[0]=1;
for (int i=1,buf1=1;i<=m;i++){
buf1=1ll*buf1*(cnt-i+1)%mod;
G[i]=1ll*buf1*ifac[i]%mod;
}G=G*F;
printf("%d",G[m]);
return 0;
}
```