多项式快速幂

· · 个人记录

#include<bits/stdc++.h>
#define N 200001
#define mo 998244353
using namespace std;
int n,m,len,k;
int a[N<<1],b[N<<1],c[N<<1],s[N<<1],t[N<<1],F[N<<1];
int qpow(int a,int x)
{
    int s=1;
    int t=a;
    while (x)
    {
        if (x % 2) s=(1LL*s*t) % mo;
        x>>=1;
        t=1LL*t*t%mo;
    }
    return s; 
}
void sl(int f[],int len)
{
    int j=len>>1;
    for (int i=1;i<len-1;i++)
    {
        if (i<j) swap(f[i],f[j]);
        int k=len>>1;
        while (j>=k)
        {
            j-=k;
            k>>=1;
        }
        if (j<k) j+=k;
    }
}
void ntt(int f[],int len,int on)
{
    sl(f,len);
    for (register int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mo-1)/i);
        if (on==-1) wn=qpow(wn,mo-2);
        for (register int j=0;j<len;j+=i)
        {
            int w=1;
            for (register int k=j;k<j+i/2;k++)
            {
                register int x=f[k];
                register int y=1LL*w*f[k+i/2] % mo;
                f[k]=1LL*(x+y)% mo;
                f[k+i/2]=(1LL*(x-y)% mo+mo) % mo;
                w=1LL*w*wn % mo;
            }
        }
    }
    if (on==-1)
        for (int i=0;i<len;i++) f[i]=1LL*f[i]*qpow(len,mo-2) % mo;
}
void inv(int n,int a[],int b[])
{
    if (n==1)
    {
        b[0]=qpow(a[0],mo-2);
        return;
    }
    inv((n+1)>>1,a,b);
    int len=1;
    while (len<(n<<1)) len<<=1;
    memset(c,0,sizeof(c));
    for (int i=0;i<n;i++) c[i]=a[i];
    ntt(c,len,1);ntt(b,len,1);
    for (int i=0;i<len;i++) b[i]=(2LL-1LL*b[i]*c[i] % mo+mo) % mo*b[i]% mo;
    ntt(b,len,-1);
    for (int i=n;i<len;i++) b[i]=0;
}
void Ln(int a[],int n)
{   
    inv(n,a,b);
    for (int i=1;i<n;i++) a[i-1]=1LL*a[i]*i%mo;
    int len=1;
    while (len<(n<<1)) len<<=1;
    ntt(a,len,1);ntt(b,len,1);
    for (int i=0;i<len;i++) a[i]=(1LL*a[i]*b[i]) % mo;
    ntt(a,len,-1);
    for (int i=n-1;i>=1;i--) a[i]=1LL*a[i-1]*qpow(i,mo-2) % mo;
    a[0]=0;
    memset(b,0,sizeof(b));
}
void exp(int n,int a[],int s[])
{
    if (n==1)
    {
        s[0]=1;
        return;
    }
    exp(n>>1,a,s);
    for (int i=0;i<n;i++) t[i]=s[i];
    Ln(t,n);
    t[0]=(a[0]+1-t[0]+mo) % mo;
    for (int i=1;i<n;i++) t[i]=(1LL*(a[i]-t[i])+mo)%mo;
    ntt(t,n<<1,1);ntt(s,n<<1,1);
    for (int i=0;i<len<<1;i++) s[i]=(1LL*s[i]*t[i]) % mo;
    ntt(s,n<<1,-1);
    for (int i=n;i<=len;i++) s[i]=t[i]=0;
}
int get()
{
    char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    int k=0;
    while (isdigit(ch))
    {
        k=((1LL*k<<1)+(1LL*k<<3)+ch-48)%mo;
        ch=getchar();
    }
    return k;
}
int main()
{
    scanf("%d",&n);
    k=get();
    for (int i=0;i<n;i++) scanf("%d",&a[i]);
    Ln(a,n);
    for (int i=0;i<n;i++) a[i]=(1LL*a[i]*k) % mo;
    len=1;
    while (len<=n) len<<=1;
    exp(len,a,s);
    for (int i=0;i<n;i++) printf("%d ",s[i]);
    printf("\n");
}