题解 P4491 【[HAOI2018]染色】
whyl
2020-04-28 17:09:15
***
[HAOI2018]染色
***
题意:
有一个长度为 *n* 的序列,*m* 种颜色,给定一个*S*,每种染色方案的价值是
$$
W_{出现次数等于S的颜色总数}
$$
求对所有染色方案的价值和,答案对 1004535809 取模。
数据范围 : *n* <=1e7 , *m* <= 1e5 ,*S* <=150
***
题解:
定义状态 *f[i]* 表示至少出现了 *i* 种颜色出现次数等于 *S* 的方案数, *g[i]* 表示恰好出现了*i* 种......
那么 *f* 的转移方程为 :
$$
f[i]=C{i,m} * \frac{n!}{(n-i*S)! *(S!)^i} * (m-i)^{n-S*i}
$$
根据二项式反演
$$
g[i]=\sum_{j=i}^{lim} (-1)^{j-i} C(j,i) f[j]
$$
进一步转化为
$$
g[i]*i!=\sum_{j=i}^{lim} \frac{(-1)^{j-i}}{(j-i)!}f[j]* j!
$$
直接 *NTT* 在对 *g* 除一个 *i!* ,答案就是
$$
ans=\sum_{i=1}^{lim} g[i] * w[i]
$$
***
代码:
```cpp
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){
int x=0,f=1;
char p=getchar();
while(!isdigit(p)){
if(p=='-') f=-1;
p=getchar();
}
while(isdigit(p)) x=(x<<3)+(x<<1)+(p^48),p=getchar();
return x*f;
}
const int maxn=1e7+5,maxm=4e5+5,mod=1004535809,G=3,Gi=334845270,N=1e7;
int n,m,s,f[maxm],cheng[maxn],inv[maxn],len=1,r[maxm],g[maxm],ans,W[maxm];
inline int power(int a,int b){
int ans=1;
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
inline int C(int n,int m){
return cheng[n]*inv[m]%mod*inv[n-m]%mod;
}
inline void ntt(int *f,int len,int op){
for(int i=0;i<len;i++) if(i<r[i]) swap(f[i],f[r[i]]);
for(int i=1;i<len;i<<=1){
int wn;if(op==1) wn=power(G,(mod-1)/(i<<1));else wn=power(Gi,(mod-1)/(i<<1));
for(int j=0;j<len;j+=(i<<1)){
int w=1;
for(int k=0;k<i;k++,w=w*wn%mod){
int nx=f[j+k],ny=f[i+j+k]*w%mod;
f[j+k]=(nx+ny)%mod;
f[i+j+k]=(nx-ny+mod)%mod;
}
}
}
if(op==1) return;
int iv=power(len,mod-2);
for(int i=0;i<len;i++) f[i]=f[i]*iv%mod;
}
signed main(){
n=read();m=read();s=read();
for(int i=0;i<=m;i++) W[i]=read();
cheng[0]=1;for(int i=1;i<=N;i++) cheng[i]=cheng[i-1]*i%mod;
inv[N]=power(cheng[N],mod-2);for(int i=N-1;i>=2;i--) inv[i]=inv[i+1]*(i+1)%mod;inv[0]=inv[1]=1;
for(int i=0;i<=min(m,n/s);i++) f[i]= C(m,i)* cheng[n] %mod *inv[n-i*s] %mod *power( power( cheng[s],i ),mod-2 ) %mod * power( m-i , (n-i*s) ) %mod;
for(int i=0;i<=min(m,n/s);i++) f[i]=f[i]*cheng[i]%mod;
reverse(f,f+1+min(m,n/s));
for(int i=0;i<=min(m,n/s);i++) g[i]=inv[i];
for(int i=0;i<=min(m,n/s);i++) if(i&1) g[i]=-g[i];
while(len<=(2*min(m,n/s))) len<<=1;
for(int i=0;i<len;i++) r[i]=((r[i>>1]>>1)|((i&1)?(len>>1):0));
ntt(f,len,1);ntt(g,len,1);for(int i=0;i<len;i++) f[i]=f[i]*g[i]%mod;
ntt(f,len,-1);
reverse(f,f+1+min(m,n/s));
for(int i=0;i<=min(m,n/s);i++) f[i]=f[i]*inv[i]%mod*W[i]%mod;
for(int i=0;i<=min(m,n/s);i++) ans=(ans+f[i])%mod;
cout<<ans<<endl;
return 0;
}
```