[数学记录]P6516 [QkOI#R1] Quark and Graph

· · 个人记录

题意 : 图 Gn 个点 m 条边,设 dis[u] 为从 1 号点出发,到达该点的最短距离。

现在给出 dis[1...n] 求符合条件的图的个数,保证至少为 1

n\leq 10^5,m\leq 2\times10^5$ ,设 $t_i=\sum\limits_{u}[dis[u]=i]$ ,则有 $\sum\limits_{i}t_it_{i-1}\leq 2\times 10^5

时限 \texttt{3s}

第一眼 : 这也能数?

看到数据范围中的奇怪约束 : ? ? ! !

其实题目的创意还是不错的,但是这个约束明示按照 dis 分层……

首先建立最短路 DAG ,那么两端在不同层的边是 DAG 边,而连接同一层的边不是 DAG 边。

而题目中的限制显然是 DAG 边数的上界,看来复杂度和这个有关。

注意到 DAG 边和非 DAG 边的选取是独立的,我们分别考虑。

G[k] 为选取了 k 条 DAG 边的方案数,令 G(x) 为其生成函数。

类似地有 F(x) 为非 DAG 边的生成函数。

容易发现,每层是独立的。设 G_r(x) 为第 r 层到下一层布置 DAG 边的生成函数。

下一层的每个点都必须至少有一条边相连,否则最短路不可能恰好为上一层加一。这些边的出发点是没有限制的。

考虑单个右侧点,可以选择连向左侧点的一个集合,但是这个集合不能为空,其生成函数为 (x+1)^{t_r}-1

G_r(x)=\Big((x+1)^{t_r}-1\Big)^{\small t_{r+1}}

直接分治 FFT 即可。 - 非 DAG 边 对于第 $r$ 层,可能的内部边有 $\binom{t_r}{2}$ 条。 内部边总数为 $T=\sum\limits_r \binom{t_r}{2}$。 则 $F(x)=\sum\limits_{i=0}\dbinom{T}{i}x^i

然而 T 可能大于模数 p ,但是不会大于 p^2 ,而 i<p 是一定成立的。

使用卢卡斯定理,得 \dbinom{T}{i}=\dbinom{T\bmod p}{i}\dbinom{\lfloor T/p\rfloor}{0}

显然 \dbinom{\lfloor T/p\rfloor}{0}=1 ,则有 \dbinom{T}{i}=\dbinom{T\bmod p}{i} ,我们简单地令 Tp 取模即可。

取模之后, T 仍可能很大,需要递推求组合数。

#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#define ll long long
#define ull unsigned long long
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
const int _G=3,mod=998244353,Maxn=1<<18|500;
using namespace std; 
ll powM(ll a,ll t=mod-2){
  ll ans=1;
  while(t){
    if(t&1)ans=ans*a%mod;
    a=a*a%mod;t>>=1;
  }return ans;
}
const int invG=powM(_G);
int tr[Maxn<<1],tf;
void tpre(int n){
  if (tf==n)return ;tf=n;
  for(int i=0;i<n;i++)
    tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
}
void NTT(int *g,bool op,int n)
{
  tpre(n);
  static ull f[Maxn<<1],w[Maxn<<1];w[0]=1;
  for (int i=0;i<n;i++)f[i]=(((ll)mod<<5)+g[tr[i]])%mod;
  for(int l=1;l<n;l<<=1){
    ull tG=powM(op?_G:invG,(mod-1)/(l+l));
    for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod;
    for(int k=0;k<n;k+=l+l)
      for(int p=0;p<l;p++){
        int tt=w[p]*f[k|l|p]%mod;
        f[k|l|p]=f[k|p]+mod-tt;
        f[k|p]+=tt;
      }
  }if (!op){
    ull invn=powM(n);
    for(int i=0;i<n;++i)
      g[i]=f[i]%mod*invn%mod;
  }else for(int i=0;i<n;++i)g[i]=f[i]%mod;
}
void px(int *f,int *g,int n)
{for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;}
#define Poly vector<int>
Poly operator + (const Poly &A,const Poly &B)
{
  Poly C=A;C.resize(max(A.size(),B.size()));
  for (int i=0;i<B.size();i++)C[i]=(C[i]+B[i])%mod;
  return C;
}
Poly operator - (const Poly &A,const Poly &B)
{
  Poly C=A;C.resize(max(A.size(),B.size()));
  for (int i=0;i<B.size();i++)C[i]=(C[i]+mod-B[i])%mod;
  return C;
}
Poly operator * (const int c,const Poly &A)
{
  Poly C;C.resize(A.size());
  for (int i=0;i<A.size();i++)C[i]=1ll*c*A[i]%mod;
  return C;
}
int lim;
Poly operator * (const Poly &A,const Poly &B)
{
  static int a[Maxn<<1],b[Maxn<<1];
  for (int i=0;i<A.size();i++)a[i]=A[i];
  for (int i=0;i<A.size();i++)a[i]=A[i];
  cpy(a,&A[0],A.size());
  cpy(b,&B[0],B.size());
  Poly C;C.resize(min(lim,(int)(A.size()+B.size()-1)));
  int n=1;for(n;n<A.size()+B.size()-1;n<<=1);
  NTT(a,1,n);NTT(b,1,n);
  px(a,b,n);NTT(a,0,n);
  cpy(&C[0],a,C.size());
  clr(a,n);clr(b,n);
  return C;
}
Poly powP(Poly A,int k)
{
  int n=A.size();
  Poly ret;ret.resize(n);ret[0]=1;
  while(k){
    if (k&1){ret=ret*A;ret.resize(n);}
    A=A*A;A.resize(n);k>>=1;
  }return ret;
}
int fac[Maxn],ifac[Maxn];
int C(int n,int m)
{return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
void Init(int n)
{
  fac[0]=1;
  for (int i=1;i<=n;i++)
    fac[i]=1ll*fac[i-1]*i%mod;
  ifac[n]=powM(fac[n]);
  for (int i=n;i;i--)
    ifac[i-1]=1ll*ifac[i]*i%mod;
}
Poly getF(int t1,int t2)
{
  Poly F;F.resize(t1+1);
  for (int i=1;i<=t1;i++)F[i]=C(t1,i);
  F.resize((F.size()-1)*t2+1);
  return powP(F,t2);
}
int c[Maxn];
Poly solve(int l,int r)
{
  if (l==r)return getF(c[l],c[l+1]);
  int mid=(l+r)>>1;
  return solve(l,mid)*solve(mid+1,r);
}
Poly F,G;
int n,m;
int main()
{
    scanf("%d%d",&n,&m);lim=m+1;
    for (int i=1,dis;i<=n;i++){
      scanf("%d",&dis);
      c[dis]++;
  }Init(m+1);
  F=solve(0,n-2);
  int cnt=0;
  for (int i=0;i<n;i++)
    cnt=(cnt+1ll*c[i]*(c[i]-1)/2)%mod;
  G.resize(m+1);
  G[0]=1;
  for (int i=1,buf1=1;i<=m;i++){
    buf1=1ll*buf1*(cnt-i+1)%mod;
    G[i]=1ll*buf1*ifac[i]%mod;
  }G=G*F;
    printf("%d",G[m]);
    return 0;
}