题解:P3321 [SDOI2015] 序列统计

· · 题解

传送门

题解

普通对数的启发

在实数中,我们有:

\log(a\times b)=\log a+\log b

乘法变成加法,这很方便。

意义下的对数

在模 m 的世界里,有没有类似的“对数”?

有!这就是离散对数

对于 m 质数,存在一个 g 原根,使得每个非零数 a 都可以写成:

a\equiv g^{\operatorname{ind}(a)}\pmod{m}

于是:

a\times b\equiv g^{\operatorname{ind}(a)+\operatorname{ind}(b)}\pmod{m}

乘法变成了指数加法!

原问题转化

原问题:

a_1\times a_2\times a_3\times\dots\times a_n\equiv x\pmod{m}

取离散对数:

\operatorname{ind}(a_1)+\operatorname{ind}(a_2)+\operatorname{ind}(a_3)+\dots+\operatorname{ind}(a_n)\equiv \operatorname{ind}(x)\pmod{m-1}

现在问题变成了:

1 个数

如果我们只选 1 个数:

定义数组 \operatorname{F}(k) 表示集合中有多少个数满足 \operatorname{ind}(a)=k 式子。

那么选 1 个数,指数和为 k 的方案数就是 \operatorname{F}(k) 个。

2 个数

2 个数,指数和为 s 的方案数:

\sum\limits_{i+j\equiv s}\operatorname{F}(i)\times\operatorname{F}(j)

这恰好是多项式卷积。

n 个数

n 个数,就是 \operatorname{F} 卷自己 n 次。

快速计算 \operatorname{F}^{*n} 式子

直接做 n 次卷积太慢。

可以用快速幂!

每次卷积用 NTT 加速。

最后特判 x=0 即可。

AC code:

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int MOD=1004535809;
const int G=3,invG=334845270;
const int N=8e3+5;
const int M=5e4+5;
int add(int x,int y,int mod=MOD){
    ll res=1ll*x+y;
    if(res>=mod){
        res-=mod;
    }
    return res;
}
int mul(int x,int y,int mod=MOD){
    ll res=1ll*x*y;
    if(res>=mod){
        res%=mod;
    }
    return res;
}
int subt(int x,int y,int mod=MOD){
    ll res=1ll*x-y;
    if(res<0){
        res+=mod;
    }
    return res;
}
int ksm(int x,int y,int mod=MOD){
    int res=1;
    while(y){
        if(y&1){
            res=mul(res,x,mod);
        }
        x=mul(x,x,mod);
        y>>=1;
    }
    return res;
}
int inv(int x){
    return ksm(x,MOD-2);
}
int find(int m){
    vector<int> factor;
    int phi=m-1,tmp=phi;
    for(int i=2;i*i<=tmp;i++){
        if(tmp%i==0){
            factor.push_back(i);
            while(tmp%i==0){
                tmp/=i;
            }
        }
    }
    if(tmp>1){
        factor.push_back(tmp);
    }
    for(int g=2;g<=m;g++){
        bool ok=1;
        for(int fac:factor){
            if(ksm(g,phi/fac,m)==1){
                ok=0;
                break;
            }
        }
        if(ok){
            return g;
        }
    }
    return -1;
}
int lim=1,len;
int rev[M];
void init(){
    for(int i=0;i<lim;i++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    }
}
void ntt(int *arr,int op){
    for(int i=0;i<lim;i++){
        if(i<rev[i]){
            swap(arr[i],arr[rev[i]]);
        }
    }
    for(int mid=1;mid<lim;mid<<=1){
        int nxt=mid<<1;
        int gn=ksm(op>0?G:invG,(MOD-1)/nxt);
        for(int i=0;i<lim;i+=nxt){
            int g=1;
            for(int j=0;j<mid;j++,g=mul(g,gn)){
                int x=arr[i+j],y=mul(g,arr[i+j+mid]);
                arr[i+j]=add(x,y),arr[i+j+mid]=subt(x,y);
            }
        }
    }
    if(op<0){
        int Inv=inv(lim);
        for(int i=0;i<lim;i++){
            arr[i]=mul(arr[i],Inv);
        }
    }
}
int n,m,k,s;
int S[N];
int ind[N];
int ta[M],tb[M];
void solve(int *a,int *b,int *res,int sz){
    for(int i=0;i<lim;i++){
        ta[i]=tb[i]=0;
    }
    for(int i=0;i<sz;i++){
        ta[i]=a[i],tb[i]=b[i];
        res[i]=0;
    }
    ntt(ta,1),ntt(tb,1);
    for(int i=0;i<lim;i++){
        ta[i]=mul(ta[i],tb[i]);
    }
    ntt(ta,-1);
    for(int i=0;i<lim;i++){
        res[i%sz]=add(res[i%sz],ta[i]);
    }
}
int f[M],fn[M];
void poly_ksm(int *x,int y,int *res,int sz){
    res[0]=1;
    while(y){
        if(y&1){
            int tmp[N];
            solve(res,x,tmp,sz);
            for(int i=0;i<sz;i++){
                res[i]=tmp[i];
            }
        }
        int tmp[N];
        solve(x,x,tmp,sz);
        for(int i=0;i<sz;i++){
            x[i]=tmp[i];
        }
        y>>=1;
    }
}
signed main(){
    //HAPPY!
    ios_base::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    cin>>n>>m>>k>>s;
    bool zero=0;
    for(int i=1;i<=s;i++){
        cin>>S[i];
        zero|=!S[i];
    }
    if(!k){
        cout<<(zero?subt(ksm(s,n),ksm(s-1,n)):0)<<"\n";
        return 0;
    }
    int g=find(m);
    memset(ind,-1,sizeof(ind));
    int cur=1;
    for(int i=0;i<m-1;i++){
        ind[cur]=i;
        cur=mul(cur,g,m);
    }
    int sz=m-1;
    for(int i=1;i<=s;i++){
        if(!S[i]){
            continue;
        }
        int idx=ind[S[i]];
        f[idx]=add(f[idx],1);
    }
    while(lim<=(sz<<1)){
        lim<<=1;
        len++;
    }
    init();
    poly_ksm(f,n,fn,sz);
    cout<<fn[ind[k]]<<"\n";
    return 0;
}