多项式学习笔记
多项式乘法本质上是加法卷积
FFT
下面为了方便,设
再做一次
实现:
迭代求解
发现按照偶数项放在左边,奇数项放在右边求解时,初始数组中二进制表示为
枚举
基础代码:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=300010;
const double PI=acos(-1);
int n,m;
struct Complex { double x,y; };
Complex operator + (Complex a,Complex b) { return {a.x+b.x,a.y+b.y}; }
Complex operator - (Complex a,Complex b) { return {a.x-b.x,a.y-b.y}; }
Complex operator * (Complex a,Complex b) { return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x}; }
Complex a[N],b[N];
int rev[N],bit,tot;
void FFT(Complex a[],int dir)
{
for (int i=0;i<tot;i++)
{
if (i<rev[i])
{
swap(a[i],a[rev[i]]);
}
}
for (int mid=1;mid<tot;mid<<=1)
{
//处理长度为2mid
Complex wn1={cos(PI/mid),dir*sin(PI/mid)};
for (int i=0;i<tot;i+=2*mid)
{
Complex wnk={1,0};
for (int j=0;j<mid;j++,wnk=wnk*wn1)
{
Complex x=a[i+j],y=wnk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
}
int main()
{
cin >> n >> m;
for (int i=0;i<=n;i++) cin >> a[i].x;
for (int i=0;i<=m;i++) cin >> b[i].x;
while ((1<<bit)<n+m+1) bit++;
tot=(1<<bit);
for (int s=0;s<tot;s++) rev[s]=(rev[s>>1]>>1)|((s&1)<<(bit-1)); //s翻转
FFT(a,1),FFT(b,1);
for (int i=0;i<tot;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for (int i=0;i<=n+m;i++)
{
double c=a[i].x/tot+0.5;
printf("%d ",(int)c);
}
return 0;
}
NTT
在模意义下快速实现多项式乘法的一种方法
你说得对,但是原根是一个数学符号。设
m 是正整数,a 是整数, 若a 模m 的阶等于\varphi(m) ,则称a 为模m 的一个原根。假设一个数g 是P 的原根,那么g^i \bmod P 的结果两两不同,且有1<g<P, 0<i<P ,归根到底就是g^x \equiv 1 \pmod P 当且仅当
用原根代替
封装:
namespace Poly
{
typedef vector<int> poly;
const int mod=998244353,G=3,niG=332748118;
int bit,tot,rev[N];
void print(poly a)
{
for (int i=0;i<a.size();i++) cout << a[i] << " ";
cout << "\n";
}
void NTT(poly &a,int dir)
{
for (int i=0;i<tot;i++)
{
if (i<rev[i])
{
swap(a[i],a[rev[i]]);
}
}
for (int mid=1;mid<tot;mid<<=1)
{
int wn1=qmi((dir==1 ? G : niG),(mod-1)/(mid<<1),mod);
for (int i=0;i<tot;i+=(mid<<1))
{
int wnk=1;
for (int j=0;j<mid;j++,wnk=(ll)wnk*wn1%mod)
{
int x=a[i+j],y=(ll)wnk*a[i+j+mid]%mod;
a[i+j]=(x+y)%mod;
a[i+j+mid]=((ll)x-y+mod)%mod;
}
}
}
}
void DFT(poly &a)
{
NTT(a,1);
}
void IDFT(poly &a)
{
NTT(a,-1);
int nitot=qmi(tot,mod-2,mod);
for (int i=0;i<a.size();i++) a[i]=(ll)a[i]*nitot%mod;
}
poly operator + (poly a,poly b)
{
int len=max(a.size(),b.size());
a.resize(len),b.resize(len);
for (int i=0;i<len;i++) a[i]=(a[i]+b[i])%mod;
return a;
}
poly operator - (poly a,poly b)
{
int len=max(a.size(),b.size());
a.resize(len),b.resize(len);
for (int i=0;i<len;i++) a[i]=((ll)a[i]-b[i]+mod)%mod;
return a;
}
poly operator * (poly a,int b)
{
for (int i=0;i<a.size();i++) a[i]=(ll)a[i]*b%mod;
return a;
}
poly operator * (poly a,poly b)
{
int n=a.size()-1,m=b.size()-1;
bit=0;
while ((1<<bit)<n+m+1) bit++;
tot=(1<<bit);
for (int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
a.resize(tot),b.resize(tot);
DFT(a),DFT(b);
for (int i=0;i<tot;i++) a[i]=(ll)a[i]*b[i]%mod;
IDFT(a);
a.resize(n+m+1);
return a;
}
poly Inv(poly F,int n)
{
poly G2,G;
G2.resize(1);
G2[0]=qmi(F[0],mod-2,mod);
for (int k=2;;k<<=1)
{
// mod x^k
// G(x)=2G2(x)-F(x)G2(x)G2(x) (mod x^n)
poly t1=G2*G2,t2=F;
t1.resize(k),t2.resize(k);
poly t=t1*t2;
t.resize(k);
G=G2*2-t;
G.resize(k);
G2=G;
if (k>=n) break;
}
G.resize(n+1);
return G;
}
};
原题: P4841
正解: