题解 P5245 多项式快速幂

· · 题解

前言:

我要强烈谴责一些没有讲清楚为什么最后化简出来的系数 k 可以取模的题解,

这是十分不负责任的行为,这些多项式题目不仅仅是背个板子,走马观花地看一

些,了解模板的推导过程才是学习多项式的目的,题目意思就不讲了,已经很清楚

了,下面讲解法。

前置知识:

多项式 ln ,多项式 exp ,多项式乘法逆。

Solution:

首先知道: B(x)≡A^k(x) (mod x^n)

考虑怎么消去次数 k ,想到对数公式。

于是有: ln B(x)≡ ln A^k(x) (mod x^n)

也即: ln B(x)≡ k×ln A(x) (mod x^n)

于是,先对多项式 A(x) 求出它的 ln ,乘以 k ,就求出了 ln B(x)

值,对 ln B(x) 做一遍多项式 exp 就是 B(x) 了。

但是次数 k 真的很大,需要一边读入一边取模,下证取模仍然使得答案正确。

由泰勒展开可得, e^x=\sum{\frac{x^i}{i!}}

把上式的 x{k×lnB(x)} 替换。

于是求和式子的分母处就是 k^i×ln^iB(x)

显然取模后不变。

下面是代码,有点长。

#include<bits/stdc++.h>
#define N 400050
#define p 998244353
#define ll long long
using namespace std;
char ch[N];
ll n,k;
ll a[N],b[N];
ll bit,pre[N],len;
ll g1=3,g2=332748118,c[N],tmp[N],lx[N],inv_num[N],Z[N],ans[N];
ll ksm(ll a,ll b)
{
    ll s=1;
    while(b)
    {
        if(b&1) s=(s*a)%p;
        a=(a*a)%p;
        b=b>>1;
    }
    return s;
} 
void NTT(ll* a,ll len,ll pd)
{
    for(register ll i=0;i<len;i++) 
    {
        if(i<pre[i]) swap(a[i],a[pre[i]]);
    }
    for (register ll mid=1;mid<=len-1;mid=(mid<<1)) 
    {
        ll gn;
        if(pd==1) gn=ksm(g1,(p-1)/(mid<<1));
        if(pd==-1) gn=ksm(g2,(p-1)/(mid<<1));   
        for(register ll i=0;i<=len-1;i+=(mid<<1)) 
        {
            ll g=1,x,y;
            for(register ll j=0;j<=mid-1;j++,g=(g*gn)%p) 
            {    
                x=a[i+j];
                y=g*a[i+j+mid]%p; 
                a[i+j]=(x+y)%p;
                a[i+j+mid]=(x-y+p)%p; 
            }
        }
    }    
    if(pd==-1)
    {
        ll inv=ksm(len,p-2);
        for(ll i=0;i<len;i++)
        {
            a[i]=(a[i]*inv)%p;
        }
    }
}
void inv(ll L,ll* a,ll* b)
{
    if(L==1)
    {
        b[0]=ksm(a[0],p-2);
        return;
    }
    inv((L+1)>>1,a,b);
    bit=0;  
    while((1<<bit)<(L<<1)) 
    {
        bit++;
    }
    len=(1<<bit);
    for(register ll i=0;i<len;i++) 
    {
        pre[i]=((pre[i>>1]>>1)|((i&1)<<(bit-1)));
    }
    for(register ll i=0;i<=L-1;i++) c[i]=a[i];
    for(register ll i=L;i<len;i++) c[i]=0;
    NTT(b,len,1);
    NTT(c,len,1);
    for(register ll i=0;i<len;i++)
    {
        b[i]=(2-c[i]*b[i]%p+p)*b[i]%p;  
    }
    NTT(b,len,-1);
    for(register ll i=L;i<len;i++) b[i]=0;
}
void Add(ll L,ll *a,ll *b)
{
    bit=0;
    while((1<<bit)<(L<<1)) 
    {
        bit++;
    }
    len=(1<<bit);
    for(register ll i=0;i<len;i++) 
    {
        pre[i]=((pre[i>>1]>>1)|((i&1)<<(bit-1)));
    }
    NTT(a,len,1);
    NTT(b,len,1);
    for(register ll i=0;i<len;i++) 
    {
        b[i]=(a[i]*b[i])%p; 
    }
    NTT(b,len,-1);
}
void D(ll *a,ll *b,ll L)
{
    for(register ll i=1;i<=L-1;i++)
    {
        b[i-1]=i*a[i]%p;
    }   
    b[L-1]=0;
}
void I(ll* a,ll* b,ll L)
{
    for(register ll i=1;i<=L-1;i++)
    {
        b[i]=a[i-1]*inv_num[i]%p;   
    }   
    b[0]=0;
}
void ln(ll L,ll* a,ll* b)
{
    for(ll i=0;i<(L<<2);i++)
    {
        b[i]=0;
    }
    inv(L,a,b);
    len=1; 
    bit=0;
    while(len<(L<<1))
    {
        bit++;
        len=(len<<1);
    }
    for(ll i=0;i<=len-1;i++)
    {
        pre[i]=(pre[i>>1]>>1)|((i&1)<<(bit-1));
    }
    for(ll i=0;i<L-1;i++) 
    {
        Z[i]=a[i+1]*(i+1)%p;
    }
    for(ll i=L-1;i<=len-1;i++) 
    {
        Z[i]=0;
    }
    NTT(Z,len,1);
    NTT(b,len,1);
    for(ll i=0;i<=len-1;i++)
    {
        b[i]=b[i]*Z[i]%p;
    }     
    NTT(b,len,-1);
    for(ll i=L-1;i>0;i--)
    {
        b[i]=b[i-1]*inv_num[i]%p;
    }  
    for(ll i=L;i<len;i++) 
    {
        b[i]=0;
    }
    b[0]=0;
}
void exp(ll* a,ll* b,ll L)
{
    if(L==1)
    {
        b[0]=1;
        return;
    }
    exp(a,b,(L+1)>>1);
    ln(L,b,lx);
    bit=0;
    while((1<<bit)<(L<<1)) 
    {
        bit++;
    }
    len=(1<<bit);
    for(register ll i=0;i<len;i++) 
    {
        pre[i]=((pre[i>>1]>>1)|((i&1)<<(bit-1)));
    }
    for(register ll i=0;i<=L-1;i++)
    {
        lx[i]=a[i]>=lx[i]?a[i]-lx[i]:a[i]-lx[i]+p;
    } 
    lx[0]++;
    for(register ll i=L;i<len;i++)
    {
        lx[i]=0;
        b[i]=0;
    }   
    NTT(lx,len,1);
    NTT(b,len,1);
    for(register ll i=0;i<len;i++)
    {
        b[i]=(b[i]*lx[i])%p; 
    }
    NTT(b,len,-1);
    for(register ll i=L;i<len;i++)
    {
        b[i]=0;
    }
}
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){x=(x*10+ch-48)%p;ch=getchar();}
    return x*f;
}
int main()
{
    scanf("%lld",&n);
    k=read();
    inv_num[1]=1;
    for(register ll i=0;i<=n-1;i++)
    {
        scanf("%lld",&a[i]);
        inv_num[i+2]=p-(p/(i+2))*inv_num[p%(i+2)]%p;
    } 
    ln(n,a,b);
    for(register ll i=0;i<=n-1;i++)
    {
        b[i]=b[i]*k%p;
    }
    exp(b,ans,n);
    for(register ll i=0;i<=n-1;i++) printf("%lld ",ans[i]);
    return 0;
}