题解 P4238 【【模板】多项式求逆】
开始搞多项式了...
推导式子看是看懂了就是好奇为什么多项式点值表示之后可以支持这么多骚操作 QVQ ...
推导如下:
假设已知
那么
又因为 A 不为 0 ,那么有:
平方一下就有:
两边乘个 A ,有 B 消 B ,有:
那么
提一个 B'
然后各种递归 + NTT 乱搞,复杂度就是
//by Judge
#include<iostream>
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int invG=332748118;
const int M=2e6+3;
typedef int arr[M];
char buf[1<<21],*p1,*p2;
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
inline int read(){ int x=0,f=1; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
for(;isdigit(c);c=getchar()) x=x*10+c-48; return x*f;
} char sr[1<<21],z[21]; int Z,C=-1;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
inline void print(int x,char ch=' '){
if(C>1<<20) Ot(); if(x<0) x=-x,sr[++C]='-';
for(;z[++Z]=x%10+48,x/=10;);
for(;sr[++C]=z[Z],--Z;); sr[++C]=ch;
} int n; arr a,b,c,r;
inline int qpow(int x,int p=mod-2){ int s=1; //用来求 逆元与数论变换的 Gn
for(;p;p>>=1,x=1ll*x*x%mod)
if(p&1) s=1ll*s*x%mod; return s;
}
inline void NTT(int* a,int n,int tp){ // NTT 的板子
for(int i=0;i<n;++i)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1){
int Gn=qpow(tp?3:invG,(mod-1)/(mid<<1));
for(int j=0;j<n;j+=mid<<1)
for(int k=0,g=1,x,y;k<mid;++k,g=1ll*g*Gn%mod)
x=a[j+k],y=1ll*g*a[j+k+mid]%mod,
a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod;
} if(tp==1) return ; int ny=qpow(n); //将答案除去 n 才是要得到的答案
for(int i=0;i<n;++i) a[i]=1ll*a[i]*ny%mod;
}
inline void work(int* a,int* b,int n){
if(n==1) return b[0]=qpow(a[0]),void(); // n=1 时可直接求出多项式 G
work(a,b,n+1>>1); int limit=1,len=0; //向下递归得到 G_(ceil(n/2)) 的解
for(;limit<n<<1;limit<<=1) ++len;
for(int i=1;i<limit;++i) //得到 rev 数组, 用于 NTT
r[i]=(r[i>>1]>>1)|((i&1)<<len-1);
for(int i=0;i<n;++i) c[i]=a[i]; //用 a 的前 n 项来进行数论变换
for(int i=n;i<limit;++i) c[i]=0; //注意后面的 n 项要清零
NTT(c,limit,1),NTT(b,limit,1); //两个多项式先数论变换一波
for(int i=0;i<limit;++i) //然后依据公式求解
b[i]=1ll*(2-1ll*c[i]*b[i]%mod+mod)*b[i]%mod;
NTT(b,limit,-1); //变换回来
for(int i=n;i<limit;++i) b[i]=0; //注意此时求出的解为 G_(n), 后面的几项要清零
}
int main(){ n=read();
for(int i=0;i<n;++i)
a[i]=read(); work(a,b,n);
for(int i=0;i<n;++i)
print(b[i]); return Ot(),0;
}