多项式

· · 个人记录

傅里叶变换(FFT)学习笔记

NTT与多项式全家桶

FFT

常量与变量

函数

代码

struct Complex{
    double x,y;
    Complex(double xx=0,double yy=0){
        x=xx,y=yy;
    }
    Complex operator*(const Complex q)const{
        return Complex(x*q.x-y*q.y,x*q.y+y*q.x);
    }
    Complex operator+(const Complex q)const{
        return Complex(x+q.x,y+q.y);
    }
    Complex operator-(const Complex q)const{
        return Complex(x-q.x,y-q.y);
    }
}a[N],b[N];
struct FFT{
    const double Pi=acos(-1.0);
    int rev[N],lim,len;
    void init(int n){
        lim=1,len=0;
        while(lim<2*n)lim<<=1,len++;
        for(int i=1;i<=lim;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
    }
    void fft(Complex *a,double flag){
        for(int i=0;i<lim;i++)
            if(i<rev[i])swap(a[i],a[rev[i]]);
        for(int i=1;i<lim;i<<=1){
            Complex w1(cos(Pi/i),flag*sin(Pi/i));
            for(int j=0;j<lim;j+=(i<<1)){
                Complex w(1,0),x,y;
                for(int k=0;k<i;k++,w=w*w1){
                    x=a[j+k],y=a[i+j+k]*w;
                    a[j+k]=x+y;a[i+j+k]=x-y;
                }
            }
        }
        if(flag<0){
            for(int i=0;i<lim;i++)
                a[i].x/=lim,a[i].y/=lim;
        }
    }
}f;

NTT

其中开方运算若多项式常数项不为 1 则要用到二次剩余求解。

#define clr(f,n) memset(f,0,sizeof(ll)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(ll)*(n))
struct Poly{
    const ll G=3,mod=998244353;
    ll qmi(ll a,int b=998244351){
        ll ans=1;
        while(b){
            if(b&1)ans=ans*a%mod;
            a=a*a%mod;b>>=1;
        }
        return ans;
    }
    const ll invG=qmi(G);
    int rv[N<<1],trv,inv[N<<1];
    void init(int n){
        if(trv==n)return;trv=n;
        inv[1]=1;
        for(int i=0;i<n;i++)
            rv[i]=(rv[i>>1]>>1)|((i&1)?n>>1:0);
        for(int i=2;i<=trv;i++)
            inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
    }
    void print(ll *f,int n){
        for(int i=0;i<n;i++)printf("%lld ",f[i]);puts("");
    }
    void NTT(ll *g,bool op,int n){
        init(n);
        static ull f[N<<1],w[N<<1];
        w[0]=1;
        for(int i=0;i<n;i++)f[i]=((mod<<5)+g[rv[i]])%mod;
        for(int l=1;l<n;l<<=1){
            ull w1=qmi(op?G:invG,(mod-1)/(l<<1));
            for(int i=1;i<l;i++)w[i]=w[i-1]*w1%mod;
            for(int i=0;i<n;i+=(l<<1)){
                for(int j=0,tt;j<l;j++){
                    tt=w[j]*f[i|l|j]%mod;
                    f[i|l|j]=f[i|j]+mod-tt;
                    f[i|j]+=tt;
                }
            }
            if(l==(1<<10))
                for(int i=0;i<n;i++)f[i]%=mod;
        }
        if(!op){
            ull invn=qmi(n);
            for(int i=0;i<n;i++)
                g[i]=f[i]%mod*invn%mod;
        }
        else for(int i=0;i<n;i++)g[i]=f[i]%mod;
    }
    void pointx(ll *f,ll *g,int n){
        for(int i=0;i<n;i++)f[i]=f[i]*g[i]%mod;
    }
    void times(ll *f,ll *g,int len,int lim){
        static ll sav[N<<1];
        int n=1;for(n;n<(len<<1);n<<=1);
        clr(f+len,n-len);clr(g+len,n-len);/*ex*/
        clr(sav,n);cpy(sav,g,n);
        NTT(f,1,n);NTT(sav,1,n);
        pointx(f,sav,n);NTT(f,0,n);
        clr(f+lim,n-lim);clr(sav,n);
    }
    void invp(ll *f,int m){
        int n;for(n=1;n<m;n<<=1);
        static ll w[N<<1],r[N<<1],sav[N<<1];
        w[0]=qmi(f[0]);
        for(int len=2;len<=n;len<<=1){
            for(int i=0;i<(len>>1);i++)r[i]=2*w[i]%mod;
            cpy(sav,f,len);
            NTT(w,1,len<<1);pointx(w,w,len<<1);
            NTT(sav,1,len<<1);pointx(w,sav,len<<1);
            NTT(w,0,len<<1);clr(w+len,len);
            for(int i=0;i<len;i++)w[i]=(r[i]-w[i]+mod)%mod;
        }
        cpy(f,w,m);clr(sav,n<<1);clr(r,n<<1);clr(w,n<<1);
    }
    void dao(ll *f,int m){
        for(int i=1;i<m;i++)f[i-1]=f[i]*i%mod;
        f[m-1]=0;
    }
    void jifen(ll *f,int m){
        for(int i=m;i;i--)f[i]=f[i-1]*inv[i]%mod;
        f[0]=0;
    }
    void lnp(ll *f,int n){
        static ll f_[N<<1];
        cpy(f_,f,n);dao(f_,n);invp(f,n);
        times(f,f_,n,n-1);jifen(f,n-1);clr(f_,n);
    }
    void exp(ll *f,int n){
        static ll a[N<<1],b[N<<1];
        int len=1;for(;len<n;len<<=1);
        cpy(a,f,n);clr(f,len);b[0]=f[0]=1;
        for(int l=2;l<=len;l<<=1){
            cpy(b,f,l>>1);lnp(f,l);
            for(int i=0;i<l;i++)
                f[i]=(a[i]-f[i]+mod)%mod;
            f[0]=(f[0]+1)%mod;
            times(f,b,l,l);
        }
        clr(a,len);clr(b,len);
    }
    void sqrtp(ll *f,int n){
        static ll a[N<<1],b[N<<1],n2=qmi(2);
        int len=1;for(;len<n;len<<=1);
        cpy(a,f,n);clr(f,len);f[0]=1;
        for(int l=2;l<=len;l<<=1){
            cpy(b,f,l>>1);times(f,f,l,l);
            for(int i=0;i<l;i++)
                f[i]=(f[i]+a[i]+mod)%mod*n2%mod;
            invp(b,l);times(f,b,l,l);
        }
        clr(a,len);clr(b,len);
    }
    void qmip(ll *f,ll k1,ll k2,int n){
        ll m=0;
        for(;m<n;m++)if(f[m])break;
        if(f[m]==0)return;
        for(int i=m;i<n;i++)f[i-m]=f[i];
        ll xk=qmi(f[0],k2),nx=qmi(f[0]);
        for(int i=0;i<n-m;i++)f[i]=f[i]*nx%mod;
        lnp(f,n-m);
        for(int i=0;i<n-m;i++)f[i]=f[i]*k1%mod;
        exp(f,n-m);
        for(int i=0;i<n-m;i++)f[i]=f[i]*xk%mod;
        m=m*k1;
        for(int i=n-1;i>=m;i--)f[i]=f[i-m];
        for(int i=0;i<m&&i<n;i++)f[i]=0;
    }
    void divp(ll *f,ll *g,ll *q,ll *r,int n,int m){
        clr(q,n);clr(r,n);
        for(int i=0;i<n-m+1;i++)q[i]=f[n-1-i];
        for(int i=0;i<n-m+1&&i<m;i++)r[i]=g[m-1-i];
        invp(r,n-m+1);times(q,r,n-m+1,n-m+1);
        for(int i=0;i*2<n-m+1;i++)swap(q[i],q[n-m-i]);
        clr(g+m-1,n-m+1);clr(r,n-m+1);cpy(r,q,m-1);
        times(g,r,m-1,m-1);
        for(int i=0;i<m-1;i++)r[i]=(f[i]-g[i]+mod)%mod;
    }
}poly;