题解 随

· · 个人记录

T1 随

解题思路

DP+矩阵乘(快速幂)+数论

又是一道与期望无关的期望题,显然答案是 总情况/情况数(n^m)。

接下来的问题就是对于总情况的求和了。题目下面就给出了一个很好的概念:原根

求原根

再看一下mod的值,不会错了,暴力求就行。根据原根的性质,判断枚举元素的 1\sim p-2 次方有没有在\mod p意义下等于 1 。

int get_Yuan()
{
    for(int i=2;i<=p;i++)
    {
        int temp=1;
        bool vis=true;
        for(int j=1;j<p-1;j++)
        {
            temp=temp*i%p;
            if(temp==1)
            {
                vis=false;
                break;
            }
        }
        if(vis)
            return i;
    }
}

原根用途

有了原根,我们就可以把几个数的乘积换成指数级别的加法了。

最后的结果也就是k_0\times g^0+k_1\times g^1+...+k_{p-1}\times g^{p-1}

k就是每一个答案(g的若干次幂)出现的次数,计算 k_i 就是从n个元素中取m次,取出的数的次方之和等于i,可以矩阵乘加速

优化

如果模数是一质数,在计算快速幂的时候,可以直接把指数%(p-1)

我们的矩阵计算的就是指数之和,所以关于矩阵的所有模数都是p-1,这样以来矩阵的规模也就缩小到了(p-1)\times(p-1),并且可以采用矩阵快速幂

注意

整个过程中mod的值有所变化:

再求幂以及log时,算到p-2就刚刚好,至于比他大的部分,在之后的运算中如果加上就会导致结果偏大,所以直接不初始化,值为0就好了。

code

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+10,mod=1e9+7;
int n,m,p,yg,sum,base,mi[N],lg[N],s[N];
int get_Yuan()
{
    for(int i=2;i<=p;i++)
    {
        int temp=1;
        bool vis=true;
        for(int j=1;j<p-1;j++)
        {
            temp=temp*i%p;
            if(temp==1)
            {
                vis=false;
                break;
            }
        }
        if(vis)
            return i;
    }
}
int ksm(int x,int y)
{
    int ans=1;
    while(y)
    {
        if(y&1)
            ans=ans*x%mod;
        y>>=1;
        x=x*x%mod;
    }
    return ans;
}
struct jz
{
    int h[1005];
    void clear()
    {
        memset(h,0,sizeof(h));
    }
    jz operator *(const jz &a) const
    {
        jz ans;
        ans.clear();
        for(int i=0;i<p-1;i++)
            for(int j=0;j<p-1;j++)
                ans.h[(i+j)%(p-1)]=(ans.h[(i+j)%(p-1)]+h[i]*a.h[j])%mod;
        return ans;
    }
}a,answer;
jz ksm(jz x,int y)
{   
    jz ans;
    ans.clear();
    ans.h[lg[1]]=1;
    while(y)
    {
        if(y&1)
            ans=ans*x;
        y>>=1;
        x=x*x;
    }
    return ans;
}
#undef int
int main()
{
    #define int register long long
    #define ll long long
    scanf("%lld%lld%lld",&n,&m,&p);
    yg=get_Yuan();
//  cout<<yg<<endl;
    base=ksm(ksm(n,m),mod-2);
    if(p==2)
    {
        cout<<1;
        return 0;
    }
    mi[0]=1;
    for(int i=1;i<p-1;i++)
    {
        mi[i]=mi[i-1]%p*yg%p;
        lg[mi[i]]=i;
//      cout<<mi[i]<<endl;
    }
    for(int i=1;i<=n;i++)
        scanf("%lld",&s[i]);
    for(int i=1;i<=n;i++)
        a.h[lg[s[i]]]++;
    answer=ksm(a,m);
    for(int i=0;i<p-1;i++)
        sum=(sum+mi[i]*answer.h[i]%mod)%mod;
    printf("%lld",sum*base%mod);
    return 0;
}