[数学记录]Uoj#450. 【集训队作业2018】复读机

command_block

2020-05-10 10:50:59

Personal

**题意** : 对一个长度为$n$的排列进行染色,一共$k$种颜色,要求每种颜色使用的次数都是$d$的倍数,求方案数。 - $d\leq 2,k\leq 5\times10^5$ - $d=3,k\leq 1000$ $n\leq 10^9$,对 $19491001$ 取模,时限 $\texttt{1s}$。 ------------ $n$的范围以及模数告诉我们这不是一道传统的多项式工业题。 对于 $d=1$ 的情况,答案显然是 $k^n$. 考虑经典的`EGF`,写出纯色排列的生成函数。 $$F(x)=\sum\limits_{d|i}\dfrac{x^i}{i!}$$ 答案就是: $$n![x^n]F(x)^k$$ 但是难以得到$F(n)$的封闭形式,这给我们提取系数带来了障碍。 我们可以使用单位根反演 : $[d|n]=\dfrac{1}{d}\sum\limits_{i=0}^{d-1}w_d^{ni}$ $$F(x)=\sum\limits_{i=0}\dfrac{1}{d}\sum\limits_{j=0}^{d-1}w_d^{ij}\dfrac{x^i}{i!}$$ $$=\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\sum\limits_{i=0}\dfrac{(w_d^jx)^i}{i!}$$ $$=\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\exp(w_d^jx)$$ 所以我们要算的就是 $$n![x^n]\left(\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\exp(w_d^jx)\right)^k$$ 对于 $d=2$ 的情况,答案是 $n![x^n]\left(\dfrac{e^x+e^{-x}}{2}\right)^k$. 使用二项式定理得到 $$=n![x^n]\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}e^{ix}e^{-x(k-i)}$$ $$=n![x^n]\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}e^{-x(k-2i)}$$ 提取系数得 $$=\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}(2i-k)^n$$ 计算复杂度为$O(k\log k)$。 对于 $d=3$ 的情况 答案是 $n![x^n]\left(\dfrac{e^x+e^{w_3x}+e^{w_3^2x}}{3}\right)^k$. 直接三项式展开 $$=n![x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}e^{ix}e^{jw_3x}e^{(k-i-j)w_3^2x}$$ $$=n![x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}e^{(i+jw_3+(k-i-j)w_3^2)x}$$ 提取系数得 $$=[x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}\big(i+jw_3+(k-i-j)w_3^2\big)^n$$ 计算复杂度为$O(k^2\log k)$。 ```cpp #include<algorithm> #include<cstdio> #define ll long long #define mod 19491001 #define MaxK 500500 using namespace std; ll powM(ll a,int t=mod-2){ ll ret=1; while(t){ if (t&1)ret=ret*a%mod; a=a*a%mod;t>>=1; }return ret; } ll fac[MaxK],ifac[MaxK]; ll C(int n,int m) {return fac[n]*ifac[m]%mod*ifac[n-m]%mod;} void Init(int lim) { fac[0]=1; for (int i=1;i<=lim;i++) fac[i]=fac[i-1]*i%mod; ifac[lim]=powM(fac[lim]); for (int i=lim;i;i--) ifac[i-1]=ifac[i]*i%mod; } int n,k,d; int main() { scanf("%d%d%d",&n,&k,&d); if (d==1){ printf("%lld",powM(k,n)); return 0; }Init(k); ll ans=0; if (d==2){ for (int i=0;i<=k;i++) ans=(ans+C(k,i)*powM(mod+i+i-k,n))%mod; printf("%lld",ans*powM(2,mod-1-k)%mod); }else { ll w=663067,w2=w*w%mod; for (int i=0;i<=k;i++){ ll sum=0; for (int j=0;j<=k-i;j++) sum=(sum+ ifac[j]*ifac[k-i-j]%mod* powM( (i+w*j+w2*(mod+k-i-j))%mod ,n) )%mod; ans=(ans+fac[k]*ifac[i]%mod*sum)%mod; }printf("%lld",ans*powM(3,mod-1-k)%mod); }return 0; } ```