多项式FFT
参考文献
《算法导论》 机械工业出版社(2019年6月第一版第22次印刷)
https://www.luogu.com.cn/blog/attack/solution-p3803
https://www.cnblogs.com/zhouzhendong/p/8831887.html
前置定义
我们用
本文中如无特殊说明,
本文我们一般把向量当做列向量看待(向量是一个
多项式加法
如果
多项式乘法
如果
多项式的表达
1.系数表达
对一个次数界为n的多项式
用系数表达对于多项式的某些运算时非常方便的。举个例子,对于多项式
现在我们来考虑两个用系数表示的,次数界为
2.点值表达
一个次数界为n的多项式
对于一个用系数表达的多项式来说,原则上计算点值是简单易行的,因为我们要做的就是选取
求值计算的逆(从一个多项式的点值表示确定系数表示)称为插值。
差值多项式的唯一性:对于任意
这个定理描述了求解线性方程组的一种插值算法,但它的时间复杂度是
点值加法: 如果
点值乘法: 如果
系数形式表示的多项式快速乘法
我们能否进行基于点值形式的线性时间乘法算法来加速基于系数形式表达的多项式乘法运算,关键在于能否快速进行多项式系数形式和点值形式的转换。
前面说过,巧妙的选取点值,其运算时间就可以变成
基础FFT
我们先来看FFT的主过程
(这里设
设
代入
代入
显然这两个式子只有常数不同,所以可以一并计算
又因为计算的过程是递归实现的,所以可以分治。
根据这种方式所写的FFT代码如下:
#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const ll N=1000010;
const double pi=acos(-1);
complex<double> a[N],b[N];
ll n,m;
inline ll read(){
ll x=0,tmp=1;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') tmp=-1;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return tmp*x;
}
void FFT(complex<double> *a,ll n,ll op){
if(!n) return;
complex<double> a0[n],a1[n];
for(ll i=0; i<n; i++){
a0[i]=a[i<<1];
a1[i]=a[i<<1|1];
}
FFT(a0,n>>1,op); FFT(a1,n>>1,op);
complex<double> W(cos(pi/n),sin(pi/n)*op),w(1,0);
for(ll i=0; i<n; i++,w*=W){
a[i]=a0[i]+w*a1[i];
a[i+n]=a0[i]-w*a1[i];
}
}
int main(){
n=read(); m=read();
for(ll i=0; i<=n; i++) a[i]=read();
for(ll i=0; i<=m; i++) b[i]=read();
for(m+=n,n=1; n<=m; n<<=1);
FFT(a,n>>1,1); FFT(b,n>>1,1);
for(ll i=0; i<n; i++) a[i]*=b[i];
FFT(a,n>>1,-1);
for(ll i=0; i<=m; i++) printf("%.0lf ",fabs(a[i].real()/n));
return 0;
}
这份代码能过loj的FFT板子(
我们需要继续优化常数
迭代实现
我们发现我们需要求的序列是原序列下标的二进制反转。
因此我们可以用
另外洛谷这题卡了STL的complex,complex要手写
代码(STL版)
#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const ll N=10000010;
const double pi=acos(-1);
ll n,m,limit;
complex<double> a[N],b[N];
ll c[N];
inline ll read(){
ll x=0,tmp=1;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') tmp=-1;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return tmp*x;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
ll y=10,len=1;
while(y<=x){
y=(y<<3)+(y<<1);
len++;
}
while(len--){
y/=10;
putchar(x/y+48);
x%=y;
}
}
void FFT(complex<double> *a,ll op){
for(ll i=0; i<limit; i++){
if(i<c[i]) swap(a[i],a[c[i]]);
}
for(ll mid=1; mid<limit; mid<<=1){
complex<double> W(cos(pi/mid),op*sin(pi/mid));
for(ll r=mid<<1,j=0; j<limit; j+=r){
complex<double> w(1,0);
for(ll l=0; l<mid; l++,w*=W){
complex<double> x=a[j+l],y=w*a[j+mid+l];
a[j+l]=x+y; a[j+mid+l]=x-y;
}
}
}
}
int main(){
n=read(); m=read();
for(ll i=0; i<=n; i++) a[i]=read();
for(ll i=0; i<=m; i++) b[i]=read();
limit=1; ll l=0;
while(limit<=n+m){
limit<<=1;
l++;
}
for(ll i=0; i<limit; i++) c[i]=(c[i>>1]>>1)|((i&1)<<(l-1));
FFT(a,1); FFT(b,1);
for(ll i=0; i<=limit; i++) a[i]*=b[i];
FFT(a,-1);
for(ll i=0; i<=n+m; i++){
write(a[i].real()/limit+0.5);
putchar(' ');
}
return 0;
}
代码(手写STL版)
#include<iostream>
#include<cstdio>
#include<cmath>
#define ll long long
using namespace std;
const ll N=10000010;
const double pi=acos(-1);
ll n,m,limit,c[N];
struct complex{
double real,imag;
complex(double X=0,double Y=0){real=X; imag=Y;}
}a[N],b[N];
inline complex operator +(complex a,complex b){return complex(a.real+b.real,a.imag+b.imag);}
inline complex operator -(complex a,complex b){return complex(a.real-b.real,a.imag-b.imag);}
inline complex operator *(complex a,complex b){return complex(a.real*b.real-a.imag*b.imag,a.real*b.imag+a.imag*b.real);}
inline ll read(){
ll x=0,tmp=1;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') tmp=-1;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return tmp*x;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
ll y=10,len=1;
while(y<=x){
y=(y<<3)+(y<<1);
len++;
}
while(len--){
y/=10;
putchar(x/y+48);
x%=y;
}
}
void FFT(complex *a,ll op){
for(ll i=0; i<limit; i++){
if(i<c[i]) swap(a[i],a[c[i]]);
}
for(ll mid=1; mid<limit; mid<<=1){
complex W(cos(pi/mid),op*sin(pi/mid));
for(ll r=mid<<1,j=0; j<limit; j+=r){
complex w(1,0);
for(ll l=0; l<mid; l++,w=w*W){
complex x=a[j+l],y=w*a[j+mid+l];
a[j+l]=x+y; a[j+mid+l]=x-y;
}
}
}
}
int main(){
n=read(); m=read();
for(ll i=0; i<=n; i++) a[i].real=read();
for(ll i=0; i<=m; i++) b[i].real=read();
limit=1; ll l=0;
while(limit<=n+m){
limit<<=1;
l++;
}
for(ll i=0; i<limit; i++) c[i]=(c[i>>1]>>1)|((i&1)<<(l-1));
FFT(a,1); FFT(b,1);
for(ll i=0; i<=limit; i++) a[i]=a[i]*b[i];
FFT(a,-1);
for(ll i=0; i<=n+m; i++){
write(a[i].real/limit+0.5);
putchar(' ');
}
return 0;
}