题解 P5282 【【模板】快速阶乘算法】

· · 题解

这道题的数据范围太小了,卡不掉 $O(\sqrt n \log^2 n)$ 的算法。

为了方便,设 $v=\lfloor\sqrt n\rfloor$。

我们考虑将 $n!$ 拆开:设 $g_v(x)=\prod_{i=1}^{v}(x+i)$,则

$$n!=\prod_{i=v^2+1}^{n} i\cdot\prod_{i=0}^{v-1}g_v(iv)$$

我们先用分治乘法求出 $g_v$ 的系数,再用多点求值一次性求出 $g_v(0),g_v(v),\dots,g_v((v-1)v)$ 这 $v$ 个点值,即可得到 $O(\sqrt n \log^2 n)$ 的优秀算法。由于模数不一定是 NTT 模数,多项式乘法可以用 MTT(任意模数多项式乘法)实现。

只带一个 $\log$ 的算法(即 $O(\sqrt n\log n)$)是用倍增维护 $g$ 的点值,具体做法可以参见 zzq 的博客。

代码:

#include<cstdio>
#include<cstdlib>
#include<cctype>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cassert>
#include<iostream>
// Loop helpers. Rep counts i upward over the closed range [a,b]; Repe counts i
// downward from a to b (both inclusive); rep counts i upward over [a,b).
// The loop variable is a plain int: the `register` storage class is not used
// because it is ill-formed from C++17 onward.
#define Rep(i,a,b) for(int i=(a);i<=(b);++i)
#define Repe(i,a,b) for(int i=(a);i>=(b);--i)
#define rep(i,a,b) for(int i=(a);i<(b);++i)
// Shorthand for the member name push_back.
#define pb push_back
// Larger / smaller of two expressions. Every argument is parenthesised so that
// operands containing low-precedence operators (&, |, ?:) group as written.
// Each argument may be evaluated twice, so avoid side effects in a and b.
#define mx(a,b) ((a)>(b)?(a):(b))
#define mn(a,b) ((a)<(b)?(a):(b))
typedef unsigned long long uint64;
typedef unsigned int uint32;
typedef long long ll;
using namespace std;

// Buffered integer I/O: a chunked reader over stdin and an in-memory output
// buffer that is written to stdout by flush().
namespace IO
{
    // Buffsize: bytes requested from stdin per fread; Output: capacity of Out.
    const uint32 Buffsize=1<<15,Output=1<<24;
    // Ch holds the current input chunk; [S,T) is the part not yet consumed.
    static char Ch[Buffsize],*S=Ch,*T=Ch;

    // Returns the next input byte, refilling Ch from stdin when it is empty.
    // Returns 0 when fread delivers no more bytes (end of input or error).
    inline char getc()
    {
        if(S==T)
        {
            S=Ch;
            T=S+fread(Ch,1,Buffsize,stdin);
            if(S==T)return 0;
        }
        return *S++;
    }

    // Pending output; nowps points one past the last byte stored.
    // NOTE: write() does not check nowps against the end of Out.
    static char Out[Output],*nowps=Out;

    // Writes the pending bytes to stdout and resets the buffer to empty.
    inline void flush(){fwrite(Out,1,nowps-Out,stdout);nowps=Out;}

    // Parses a decimal integer into x. Non-digit bytes before the number are
    // skipped; a '-' among them makes the result negative. If the input ends
    // before any digit is seen, x is left as 0.
    template<typename T>inline void read(T&x)
    {
        x=0;
        T sign=1;
        char ch=getc();
        // getc() returns 0 on every call once the input is exhausted, so the
        // skip loop must stop on 0 or it would never terminate. A literal NUL
        // byte in the input is therefore treated like end of input.
        // The cast keeps isdigit's argument in the unsigned char range.
        for(;ch&&!isdigit((unsigned char)ch);ch=getc())if(ch=='-')sign=-1;
        for(;isdigit((unsigned char)ch);ch=getc())x=x*10+(ch^48);
        x*=sign;
    }

    // Appends the decimal text of x, followed by the byte ch, to Out.
    template<typename T>inline void write(T x,char ch='\n')
    {
        if(!x)*nowps++='0';
        if(x<0)*nowps++='-',x=-x;
        // Digits are produced least-significant first, then emitted in reverse.
        static uint32 sta[111],tp;
        for(tp=0;x;x/=10)sta[++tp]=x%10;
        for(;tp;*nowps++=sta[tp--]^48);
        *nowps++=ch;
    }
}
using namespace IO;

// Unless ONLINE_JUDGE is defined at compile time, rebinds stdin to the file
// "water.in" and stdout to the file "water.out". The streams returned by
// freopen are stored in locals and not examined.
void file()
{
#if !defined(ONLINE_JUDGE)
    FILE*inputStream=freopen("water.in","r",stdin);
    FILE*outputStream=freopen("water.out","w",stdout);
#endif
}

// Dimension (2^17) of the polynomial and FFT work arrays declared below.
const int MAXN=1<<17;

// Modulus for all modular arithmetic in this file; main() reads it from input.
static int mod;

namespace poly
{
    // Square-and-multiply exponentiation: starting from 1, multiplies in the
    // current power of u (reduced modulo the file-level `mod`) for every set
    // bit of v, scanning v from its lowest bit upward, and returns the result.
    inline int power(int u,int v)
    {
        int acc=1;
        while(v)
        {
            if(v&1)acc=(uint64)acc*u%mod;
            v>>=1;
            u=(uint64)u*u%mod;
        }
        return acc;
    }

    // Len: current transform size, set by mul(). rev: index permutation that
    // calrev() fills for that size and FFT() applies before its butterflies.
    static int Len,rev[MAXN];

    // Minimal complex number (re + im*i) used by the FFT code, together with
    // the table g that predone() fills. The arithmetic operators are
    // const-qualified so they can be applied to const operands as well.
    static struct comp
    {
        double re,im;

        // Default construction does not assign re/im.
        comp(){}

        comp(double _r,double _i):re(_r),im(_i){}

        // Component-wise sum.
        comp operator+(const comp&a)const{return comp(re+a.re,im+a.im);}

        // Component-wise difference.
        comp operator-(const comp&a)const{return comp(re-a.re,im-a.im);}

        // Complex product.
        comp operator*(const comp&a)const
        {return comp(re*a.re-im*a.im,re*a.im+im*a.re);}

        // Divides both components by the integer a.
        comp operator/(const int&a)const{return comp(re/a,im/a);}
    }g[19][MAXN];

    namespace MTT
    {
        const double pi=acos(-1);

        inline void predone()
        {
            Rep(i,1,17)rep(j,0,1<<i)
                g[i][j]=comp(cos(2*pi*j/pow(2,i)),sin(2*pi*j/pow(2,i)));
        }

        // In-place transform of F[0..Len-1]. Len and rev must already be set
        // (mul() sets Len and calls calrev()) and g must be filled by predone().
        // tl == -1 selects the inverse transform; any other value runs only the
        // forward butterflies.
        inline void FFT(comp*F,int tl)
        {
            // Reorder F by the rev permutation; i<rev[i] swaps each pair once.
            rep(i,0,Len)if(i<rev[i])swap(F[rev[i]],F[i]);
            // i is the current block size, ii its half, t = log2(i), so that
            // g[t][k] is the k-th twiddle factor for blocks of size i.
            for(register int i=2,ii=1,t=1;i<=Len;i<<=1,ii<<=1,++t)
                for(register int j=0;j<Len;j+=i)rep(k,0,ii)
                {
                    comp x=F[j+k+ii]*g[t][k];
                    F[j+k+ii]=F[j+k]-x;
                    F[j+k]=F[j+k]+x;
                }
            if(tl==-1)
            {
                // Inverse: reverse entries 1..Len-1 of the forward result and
                // divide every entry by Len.
                reverse(F+1,F+Len);
                rep(i,0,Len)F[i]=F[i]/Len;
            }
        }

        // Scratch copy used by DBLFFT in the forward direction.
        static comp Z[MAXN];

        // Handles two real sequences with a single complex transform.
        // tl == 1 : on entry F[i] packs two real sequences as (re, im); G is
        //           only written. On exit F holds the transform of the re parts
        //           and G the transform of the im parts.
        // otherwise: on entry F and G hold two spectra; they are combined as
        //           F + i*G, inverse-transformed once, and on exit F[i].re and
        //           G[i].re hold the real and imaginary parts of that result
        //           (the .im fields are set to 0).
        inline void DBLFFT(comp*F,comp*G,int tl)
        {
            if(tl==1)
            {
                memcpy(Z,F,sizeof(comp)*Len);
                FFT(Z,1);
                // F becomes conj(Z[(Len-i) mod Len]): copy, reverse entries
                // 1..Len-1, negate the imaginary parts.
                memcpy(F,Z,sizeof(comp)*Len);
                reverse(F+1,F+Len);
                rep(i,0,Len)F[i].im=-F[i].im;
                // (Z+F)/2 is the spectrum of the re parts. (Z-F)/2 is i times
                // the spectrum of the im parts; the swap followed by negating
                // .im divides it by i. G is computed before F is overwritten.
                rep(i,0,Len)G[i]=(Z[i]-F[i])/2,F[i]=(Z[i]+F[i])/2
                    ,swap(G[i].re,G[i].im),G[i].im=-G[i].im;
            }
            else
            {
                // F[i] <- F[i] + i*G[i]
                rep(i,0,Len)F[i]=comp(F[i].re-G[i].im,F[i].im+G[i].re);
                FFT(F,-1);
                // Split the result: imaginary part to G, real part stays in F.
                rep(i,0,Len)G[i]=comp(F[i].im,0),F[i]=comp(F[i].re,0);
            }
        }
    }
    using MTT::predone;
    using MTT::DBLFFT;
    using MTT::FFT;

    inline void calrev()
    {
        int II=log(Len)/log(2)-1;
        Rep(i,1,Len-1)rev[i]=rev[i>>1]>>1|(i&1)<<II;
    }

    // Returns u+v with one conditional subtraction of `mod`: the sum is formed
    // in unsigned arithmetic and `mod` is subtracted when the sum is >= mod.
    inline int ad(int u,int v)
    {
        const unsigned total=(unsigned)u+(unsigned)v;
        if(total>=(unsigned)mod)return (int)(total-(unsigned)mod);
        return (int)total;
    }

    // Integer scratch: X, Y and Iv are used by Inv(); mlx is the accumulator of
    // mul()'s schoolbook path.
    static int X[MAXN],Y[MAXN],Iv[MAXN],mlx[MAXN];

    // Complex scratch for mul()'s FFT path.
    static comp A[MAXN],B[MAXN],C[MAXN],D[MAXN];

    // Polynomial product modulo `mod`: H = F * G, where F has coefficients
    // F[0..lenl] and G has coefficients G[0..lenr] (lenl and lenr are degrees,
    // not counts). H may alias F or G: both paths finish reading the inputs
    // before writing H.
    // The schoolbook path writes H[0..lenl+lenr]; the FFT path writes
    // H[0..Len-1], with zeros above index lenl+lenr. Overwrites Len and rev.
    inline void mul(int*F,int*G,int*H,int lenl,int lenr)
    {
        // Small inputs: schoolbook product accumulated in mlx.
        if((ll)lenl*lenr<=300)
        {
            Rep(i,0,lenl+lenr)mlx[i]=0;
            Rep(i,0,lenl)Rep(j,0,lenr)
                mlx[i+j]=(mlx[i+j]+(ll)F[i]*G[j])%mod;
            Rep(i,0,lenl+lenr)H[i]=mlx[i];
            return;
        }
        // Smallest power of two strictly greater than lenl+lenr.
        for(Len=2;Len<=lenl+lenr;Len<<=1);
        calrev();
        // Split every coefficient into a high part (>>15, stored in re) and a
        // low 15-bit part (&32767, stored in im), zero-padding up to Len.
        rep(i,0,Len)A[i]=i<=lenl?comp(F[i]>>15,F[i]&32767):comp(0,0),
            C[i]=i<=lenr?comp(G[i]>>15,G[i]&32767):comp(0,0);
        // Afterwards A/B hold the spectra of F's high/low parts and C/D those
        // of G's high/low parts.
        DBLFFT(A,B,1),DBLFFT(C,D,1);
        rep(i,0,Len)
        {
            register comp t=A[i];
            A[i]=B[i]*C[i]+A[i]*D[i];  // cross terms: lowF*highG + highF*lowG
            B[i]=B[i]*D[i];            // lowF*lowG
            C[i]=C[i]*t;               // highG*highF
        }
        DBLFFT(A,B,-1),FFT(C,-1);
        rep(i,0,Len)
        {
            // Round each real part to the nearest integer and recombine as
            // high*2^30 + cross*2^15 + low, reducing modulo `mod`.
            register int lz=((((ll)floor(C[i].re+0.5))%mod<<30)%mod
                 +((ll)floor(A[i].re+0.5))%mod*32768+(ll)floor(B[i].re+0.5))%mod;
            H[i]=i<=lenl+lenr?lz:0;
        }
    }

    // Power-series inverse: writes to G[0..ln] coefficients Iv such that
    // F * Iv == 1 modulo x^(ln+1) and modulo `mod`, by doubling the precision
    // each round. G may alias F (F is copied into X before each round's
    // products and G is written only at the end).
    // Side effect: the loop zeroes entries of F at indices above ln.
    // The seed is F[0]^(mod-2), which is the inverse of F[0] presumably
    // because `mod` is prime (Fermat) - TODO confirm against the input spec.
    inline void Inv(int*F,int*G,int ln)
    {
        Iv[0]=power(F[0],mod-2);
        // Entering a round, Iv is correct to Ln/2 terms; the round extends it
        // to Ln terms.
        for(register int Ln=2;Ln>>1<=ln;Ln<<=1)
        {
            rep(i,ln+1,Ln)F[i]=0;
            rep(i,0,Ln)X[i]=F[i],Y[i]=0;
            rep(i,0,(Ln>>1))Y[i]=Iv[i];
            // X = Iv*F, then subtract 1 from its constant term.
            mul(Y,X,X,(Ln>>1)-1,Ln-1),--X[0];
            // X = (Iv*F - 1) * Iv
            mul(X,Y,X,Ln-1,(Ln>>1)-1);
            // New upper half of Iv is the negated upper half of X.
            rep(i,(Ln>>1),Ln)Iv[i]=mod-X[i];
        }
        Rep(i,0,ln)G[i]=Iv[i];
    }

    // Scratch arrays for Div().
    static int ExX[MAXN],ExY[MAXN];

    // Polynomial division with remainder modulo `mod`: for F with coefficients
    // F[0..lenf] and G with coefficients G[0..leng], writes the quotient to
    // Q[0..lenf-leng] and the remainder to R[0..leng-1], so that F = Q*G + R.
    // R may alias F (main's caller does this): each R[i] is written right
    // after F[i] is read.
    inline void Div(int*F,int*G,int*Q,int*R,int lenf,int leng)
    {
        // Coefficient-reversed copies of F and G.
        Rep(i,0,lenf)ExX[i]=F[lenf-i];
        Rep(i,0,leng)ExY[i]=G[leng-i];
        Rep(i,leng+1,lenf-leng)ExY[i]=0;
        // Inverse of reversed G, to lenf-leng+1 terms.
        Inv(ExY,ExY,lenf-leng);
        Rep(i,lenf-leng+1,lenf)ExX[i]=0;
        Rep(i,lenf-leng+1,leng)ExY[i]=0;
        // Reversed quotient = reversed F times that inverse, truncated.
        mul(ExX,ExY,ExY,lenf-leng,lenf-leng);
        Rep(i,lenf-leng+1,(lenf-leng)<<1)ExY[i]=0;
        // Undo the reversal to obtain the quotient in ExY[0..lenf-leng].
        Rep(i,0,(lenf-leng)>>1)swap(ExY[i],ExY[lenf-leng-i]);
        // ExX = Q*G
        mul(ExY,G,ExX,lenf-leng,leng);
        // Consistency check: coefficient leng of Q*G must equal that of F.
        assert(F[leng]==ExX[leng]);
        Rep(i,0,leng-1)R[i]=ad(F[i],mod-ExX[i]);
        Rep(i,0,lenf-leng)Q[i]=ExY[i];
    }

    namespace Extend
    {
        // solv[lev][dir] receives the product polynomial built by calc() at
        // recursion depth lev for its left (dir 0) or right (dir 1) half.
        static int solv[18][2][MAXN];

        // Stores in solv[lev][dir] the coefficients of the product of
        // (x - a[i]) for i = l..r, taken modulo `mod`, building it by
        // recursive halving. Requires l <= r.
        void calc(int*a,int l,int r,int lev,int dir)
        {
            int*out=solv[lev][dir];
            if(l==r)
            {
                // Single factor x - a[l]: constant term mod-a[l], leading 1.
                out[0]=mod-a[l];
                out[1]=1;
                return;
            }
            const int md=(l+r)>>1;
            calc(a,l,md,lev+1,0);
            calc(a,md+1,r,lev+1,1);
            // The halves have degrees md-l+1 and r-md.
            mul(solv[lev+1][0],solv[lev+1][1],out,md-l+1,r-md);
        }

        // pol[lev]: the polynomial being evaluated at recursion depth lev,
        // already reduced for the current index range. Q: quotient output of
        // Div(), not used afterwards.
        static int pol[18][MAXN],Q[MAXN];

        // Multipoint evaluation by repeated remaindering. Evaluates the
        // polynomial held in pol[0] (degree n at the top-level call) at the
        // points a[l..r] and stores the values in ans[l..r]. lev is the
        // recursion depth; n is the degree of the polynomial inherited from the
        // parent level.
        void getnum(int*a,int*ans,int l,int r,int lev,int n)
        {
            if(n>r-l)
            {
                // Degree is too high for this range: build the product of
                // (x - a[i]) over l..r into solv[0][0] and keep only the
                // remainder, whose degree is at most r-l. The source is the
                // parent's polynomial (pol[0] itself when lev is 0).
                calc(a,l,r,0,0);
                Div(pol[lev-(lev>0)],solv[0][0],Q,pol[lev],n,r-l+1);
                n=r-l;
            }
            // No reduction needed: copy the parent's polynomial down a level.
            else if(lev){Rep(i,0,n)pol[lev][i]=pol[lev-1][i];}
            if(l==r)
            {
                // Remainder modulo (x - a[l]) is a constant: the value at a[l].
                ans[l]=pol[lev][0];
                return;
            }
            int md=(l+r)>>1;
            getnum(a,ans,l,md,lev+1,n);
            getnum(a,ans,md+1,r,lev+1,n);
        }
    }
}
using poly::power;
using poly::predone;
using poly::ad;
using poly::Extend::pol;
using poly::Extend::getnum;
using poly::Extend::calc;
using poly::Extend::solv;

static int n,Z[MAXN],as[MAXN];

int main()
{
    file();
    predone();
    read(n),read(mod);
    int v=sqrt(n);
    Rep(i,1,v)Z[i]=mod-i;
    calc(Z,1,v,0,0);
    memcpy(pol[0],solv[0][0],sizeof(int)*(v+1));
    Rep(i,1,v)Z[i]=(i-1)*v;
    getnum(Z,as,1,v,0,v);
    static int ans=1;
    Rep(i,1,v)ans=(ll)ans*as[i]%mod;
    Rep(i,v*v+1,n)ans=(ll)ans*i%mod;
    cout<<ans<<endl;
    return 0;
}