[数学记录]Uoj#450. 【集训队作业2018】复读机
command_block
2020-05-10 10:50:59
**题意** : 对一个长度为$n$的排列进行染色,一共$k$种颜色,要求每种颜色使用的次数都是$d$的倍数,求方案数。
- $d\leq 2,k\leq 5\times10^5$
- $d=3,k\leq 1000$
$n\leq 10^9$,对 $19491001$ 取模,时限 $\texttt{1s}$。
------------
$n$的范围以及模数告诉我们这不是一道传统的多项式工业题。
对于 $d=1$ 的情况,答案显然是 $k^n$.
考虑经典的`EGF`,写出纯色排列的生成函数。
$$F(x)=\sum\limits_{d|i}\dfrac{x^i}{i!}$$
答案就是:
$$n![x^n]F(x)^k$$
但是难以得到$F(n)$的封闭形式,这给我们提取系数带来了障碍。
我们可以使用单位根反演 : $[d|n]=\dfrac{1}{d}\sum\limits_{i=0}^{d-1}w_d^{ni}$
$$F(x)=\sum\limits_{i=0}\dfrac{1}{d}\sum\limits_{j=0}^{d-1}w_d^{ij}\dfrac{x^i}{i!}$$
$$=\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\sum\limits_{i=0}\dfrac{(w_d^jx)^i}{i!}$$
$$=\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\exp(w_d^jx)$$
所以我们要算的就是
$$n![x^n]\left(\dfrac{1}{d}\sum\limits_{j=0}^{d-1}\exp(w_d^jx)\right)^k$$
对于 $d=2$ 的情况,答案是 $n![x^n]\left(\dfrac{e^x+e^{-x}}{2}\right)^k$.
使用二项式定理得到
$$=n![x^n]\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}e^{ix}e^{-x(k-i)}$$
$$=n![x^n]\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}e^{-x(k-2i)}$$
提取系数得
$$=\dfrac{1}{2^k}\sum\limits_{i=0}^k\dbinom{k}{i}(2i-k)^n$$
计算复杂度为$O(k\log k)$。
对于 $d=3$ 的情况
答案是 $n![x^n]\left(\dfrac{e^x+e^{w_3x}+e^{w_3^2x}}{3}\right)^k$.
直接三项式展开
$$=n![x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}e^{ix}e^{jw_3x}e^{(k-i-j)w_3^2x}$$
$$=n![x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}e^{(i+jw_3+(k-i-j)w_3^2)x}$$
提取系数得
$$=[x^n]\dfrac{1}{3^k}\sum\limits_{i=0}^k\dbinom{k}{i}\sum\limits_{j=0}^i\dbinom{i}{j}\big(i+jw_3+(k-i-j)w_3^2\big)^n$$
计算复杂度为$O(k^2\log k)$。
```cpp
#include<algorithm>
#include<cstdio>
#define ll long long
#define mod 19491001
#define MaxK 500500
using namespace std;
ll powM(ll a,int t=mod-2){
ll ret=1;
while(t){
if (t&1)ret=ret*a%mod;
a=a*a%mod;t>>=1;
}return ret;
}
ll fac[MaxK],ifac[MaxK];
ll C(int n,int m)
{return fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
void Init(int lim)
{
fac[0]=1;
for (int i=1;i<=lim;i++)
fac[i]=fac[i-1]*i%mod;
ifac[lim]=powM(fac[lim]);
for (int i=lim;i;i--)
ifac[i-1]=ifac[i]*i%mod;
}
int n,k,d;
int main()
{
scanf("%d%d%d",&n,&k,&d);
if (d==1){
printf("%lld",powM(k,n));
return 0;
}Init(k);
ll ans=0;
if (d==2){
for (int i=0;i<=k;i++)
ans=(ans+C(k,i)*powM(mod+i+i-k,n))%mod;
printf("%lld",ans*powM(2,mod-1-k)%mod);
}else {
ll w=663067,w2=w*w%mod;
for (int i=0;i<=k;i++){
ll sum=0;
for (int j=0;j<=k-i;j++)
sum=(sum+
ifac[j]*ifac[k-i-j]%mod*
powM( (i+w*j+w2*(mod+k-i-j))%mod ,n)
)%mod;
ans=(ans+fac[k]*ifac[i]%mod*sum)%mod;
}printf("%lld",ans*powM(3,mod-1-k)%mod);
}return 0;
}
```