[数学记录]P6516 [QkOI#R1] Quark and Graph

command_block

2020-08-07 13:24:08

Personal

**题意** : 图 $G$ 有 $n$ 个点 $m$ 条边,设 $dis[u]$ 为从 $1$ 号点出发,到达该点的最短距离。 现在给出 $dis[1...n]$ 求符合条件的图的个数,保证至少为 $1$。 $n\leq 10^5,m\leq 2\times10^5$ ,设 $t_i=\sum\limits_{u}[dis[u]=i]$ ,则有 $\sum\limits_{i}t_it_{i-1}\leq 2\times 10^5$ 时限 $\texttt{3s}$。 ------------ 第一眼 : 这也能数? 看到数据范围中的奇怪约束 : ? ? ! ! 其实题目的创意还是不错的,但是这个约束明示按照 $dis$ 分层…… 首先建立最短路 DAG ,那么两端在不同层的边是 DAG 边,而连接同一层的边不是 DAG 边。 而题目中的限制显然是 DAG 边数的上界,看来复杂度和这个有关。 注意到 DAG 边和非 DAG 边的选取是独立的,我们分别考虑。 设 $G[k]$ 为选取了 $k$ 条 DAG 边的方案数,令 $G(x)$ 为其生成函数。 类似地有 $F(x)$ 为非 DAG 边的生成函数。 - DAG 边 容易发现,每层是独立的。设 $G_r(x)$ 为第 $r$ 层到下一层布置 DAG 边的生成函数。 下一层的每个点都必须至少有一条边相连,否则最短路不可能恰好为上一层加一。这些边的出发点是没有限制的。 考虑单个右侧点,可以选择连向左侧点的一个集合,但是这个集合不能为空,其生成函数为 $(x+1)^{t_r}-1$。 则 $G_r(x)=\Big((x+1)^{t_r}-1\Big)^{\small t_{r+1}}$ $G(x)$ 即为各个 $G_r(x)$ 的乘积,次数由奇怪的约束保证,不会高于 $\rm20w$。 直接分治 FFT 即可。 - 非 DAG 边 对于第 $r$ 层,可能的内部边有 $\binom{t_r}{2}$ 条。 内部边总数为 $T=\sum\limits_r \binom{t_r}{2}$。 则 $F(x)=\sum\limits_{i=0}\dbinom{T}{i}x^i$ 然而 $T$ 可能大于模数 $p$ ,但是不会大于 $p^2$ ,而 $i<p$ 是一定成立的。 使用卢卡斯定理,得 $\dbinom{T}{i}=\dbinom{T\bmod p}{i}\dbinom{\lfloor T/p\rfloor}{0}$ 显然 $\dbinom{\lfloor T/p\rfloor}{0}=1$ ,则有 $\dbinom{T}{i}=\dbinom{T\bmod p}{i}$ ,我们简单地令 $T$ 对 $p$ 取模即可。 取模之后, $T$ 仍可能很大,需要递推求组合数。 ```cpp #include<algorithm> #include<cstring> #include<cstdio> #include<vector> #define ll long long #define ull unsigned long long #define clr(f,n) memset(f,0,sizeof(int)*(n)) #define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n)) const int _G=3,mod=998244353,Maxn=1<<18|500; using namespace std; ll powM(ll a,ll t=mod-2){ ll ans=1; while(t){ if(t&1)ans=ans*a%mod; a=a*a%mod;t>>=1; }return ans; } const int invG=powM(_G); int tr[Maxn<<1],tf; void tpre(int n){ if (tf==n)return ;tf=n; for(int i=0;i<n;i++) tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0); } void NTT(int *g,bool op,int n) { tpre(n); static ull f[Maxn<<1],w[Maxn<<1];w[0]=1; for (int i=0;i<n;i++)f[i]=(((ll)mod<<5)+g[tr[i]])%mod; for(int l=1;l<n;l<<=1){ ull tG=powM(op?_G:invG,(mod-1)/(l+l)); for (int i=1;i<l;i++)w[i]=w[i-1]*tG%mod; for(int k=0;k<n;k+=l+l) for(int p=0;p<l;p++){ int tt=w[p]*f[k|l|p]%mod; f[k|l|p]=f[k|p]+mod-tt; f[k|p]+=tt; } }if (!op){ ull invn=powM(n); for(int i=0;i<n;++i) g[i]=f[i]%mod*invn%mod; }else for(int i=0;i<n;++i)g[i]=f[i]%mod; } void px(int *f,int *g,int n) {for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;} #define Poly vector<int> Poly operator + (const Poly &A,const Poly &B) { Poly C=A;C.resize(max(A.size(),B.size())); for (int i=0;i<B.size();i++)C[i]=(C[i]+B[i])%mod; return C; } Poly operator - (const Poly &A,const Poly &B) { Poly C=A;C.resize(max(A.size(),B.size())); for (int i=0;i<B.size();i++)C[i]=(C[i]+mod-B[i])%mod; return C; } Poly operator * (const int c,const Poly &A) { Poly C;C.resize(A.size()); for (int i=0;i<A.size();i++)C[i]=1ll*c*A[i]%mod; return C; } int lim; Poly operator * (const Poly &A,const Poly &B) { static int a[Maxn<<1],b[Maxn<<1]; for (int i=0;i<A.size();i++)a[i]=A[i]; for (int i=0;i<A.size();i++)a[i]=A[i]; cpy(a,&A[0],A.size()); cpy(b,&B[0],B.size()); Poly C;C.resize(min(lim,(int)(A.size()+B.size()-1))); int n=1;for(n;n<A.size()+B.size()-1;n<<=1); NTT(a,1,n);NTT(b,1,n); px(a,b,n);NTT(a,0,n); cpy(&C[0],a,C.size()); clr(a,n);clr(b,n); return C; } Poly powP(Poly A,int k) { int n=A.size(); Poly ret;ret.resize(n);ret[0]=1; while(k){ if (k&1){ret=ret*A;ret.resize(n);} A=A*A;A.resize(n);k>>=1; }return ret; } int fac[Maxn],ifac[Maxn]; int C(int n,int m) {return 1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;} void Init(int n) { fac[0]=1; for (int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod; ifac[n]=powM(fac[n]); for (int i=n;i;i--) ifac[i-1]=1ll*ifac[i]*i%mod; } Poly getF(int t1,int t2) { Poly F;F.resize(t1+1); for (int i=1;i<=t1;i++)F[i]=C(t1,i); F.resize((F.size()-1)*t2+1); return powP(F,t2); } int c[Maxn]; Poly solve(int l,int r) { if (l==r)return getF(c[l],c[l+1]); int mid=(l+r)>>1; return solve(l,mid)*solve(mid+1,r); } Poly F,G; int n,m; int main() { scanf("%d%d",&n,&m);lim=m+1; for (int i=1,dis;i<=n;i++){ scanf("%d",&dis); c[dis]++; }Init(m+1); F=solve(0,n-2); int cnt=0; for (int i=0;i<n;i++) cnt=(cnt+1ll*c[i]*(c[i]-1)/2)%mod; G.resize(m+1); G[0]=1; for (int i=1,buf1=1;i<=m;i++){ buf1=1ll*buf1*(cnt-i+1)%mod; G[i]=1ll*buf1*ifac[i]%mod; }G=G*F; printf("%d",G[m]); return 0; } ```