多项式FFT

· · 个人记录

参考文献

《算法导论》 机械工业出版社(2019年6月第一版第22次印刷)

https://www.luogu.com.cn/blog/attack/solution-p3803

https://www.cnblogs.com/zhouzhendong/p/8831887.html

前置定义

我们用A(x)表示一个以x为变量的多项式,将A(x)表示为形式和:A(x)=\sum_{i=0}^{n-1} a_ix^i

本文中如无特殊说明,i=\sqrt{-1}

本文我们一般把向量当做列向量看待(向量是一个n*1的矩阵,即一个含有n个元素的列)

多项式加法

如果A(x)B(x)是次数界为n的多项式,那么它们的和也是一个次数界为n的多项式C(x),对所有的定义域的x,都有C(x)=A(x)+B(x),也就是说,若A(x)=\sum_{i=0}^{n-1} a_ix^i,且B(x)=\sum_{i=0}^{n-1} b_ix^i ,则C(x)=\sum_{i=0}^{n-1} c_ix^i。对于任意的i=0,1 \dots n-1,c_i=a_i+b_i。举个例子,如果有A(x)=6x^3+7x^2-10x+9,且B(x)=-2x^3+4x-5,那么C(x)=4x^3+7x^2-6x+4

多项式乘法

如果A(x)B(x)是次数界为n的多项式,那么它们的乘积是一个次数界为2n-1的多项式C(x)。对于所有的定义域都有C(x)=A(x)B(x),也就是说,若A(x)=\sum_{i=0}^{n-1} a_ix^i,且B(x)=\sum_{i=0}^{n-1} b_ix^i,则C(x)=\sum_{i=0}^{2n-2} c_ix^i。对于任意的i=0,1 \dots n-1,c_i=\sum_{j=0}^i a_j b_{i-j}。如果有A(x)=6x^3+7x^2-10x+9,且B(x)=-2x^3+4x-5,那么C(x)=-12x^6-14x^5+44x^4-20x^3+75x^2+86x-45。如果A是次数界为n的多项式而B是次数界为m的多项式,那么C是一个次数界为n+m-1的多项式,也可以说,C是一个次数界为n+m-1的多项式。

多项式的表达

1.系数表达

对一个次数界为n的多项式A(x)=\sum_{i=0}^{n-1}a_ix^i而言,其系数表达是由一个系数组成的向量a=(a_0,a_1,\dots,a_{n-1})

用系数表达对于多项式的某些运算时非常方便的。举个例子,对于多项式A(x)在定点x_0的求值,我们可以在O(n)的时间内完成求值运算:A(x_0)=a_0+x_0(a_1+x_0(a_2+ \dots +x_0(a_{n-2}+x_0*a_{n-1}) \dots ))

现在我们来考虑两个用系数表示的,次数界为n的多项式A(x)B(x)乘法运算。如果用上面的常规方法,时间复杂度为O(n^2)。当用系数表达的时候,求值似乎更困难。由式子c_i=\sum_{j=0}^i a_j b_{i-j}推出的系数向量c称为输入向量a和b的卷积,表示为c=a \bigotimes b

2.点值表达

一个次数界为n的多项式A(x)点值表达是一个由n个点值组成的集合\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1}\},使得对于任意k=0,1,\dots,n-1,所有x_k各不相同,y_k=A(x_k)。注意一个多项式可以有很多不同的点值表达,因为可以用n各不同的点x_0,x_1,\dots,x_{n-1}构成的集合表达这种方法。

对于一个用系数表达的多项式来说,原则上计算点值是简单易行的,因为我们要做的就是选取n各不同的x_0,x_1,\dots,x_{n-1},然后对于k=0,1,\dots,n-1求出A(x_k)。我们发现求点值所需的时间复杂度是O(n^2),稍后我们会发现,如果巧妙的选取点值,其运算时间就可以变成O(n log_2 n)

求值计算的逆(从一个多项式的点值表示确定系数表示)称为插值

差值多项式的唯一性:对于任意n个点组成的集合\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1}\},其中所有的x_k都不同;那么存在唯一的次数界为n的多项式A(x),满足y_k=a(x_k),k=0,1,\dots,n-1

这个定理描述了求解线性方程组的一种插值算法,但它的时间复杂度是O(n^3)。更快的方法是基于拉格朗日公式A(x)=\sum_{i=0}^{n-1}y_i \frac{\prod_{i \neq j}(x-x_j)}{\prod_{i \neq j} (x_i-x_j)}。我们可以验证等式的右端是一个次数界为n的多项式,并满足对于所有kA(x_k)=y_k。这样,我们就可以在O(n^2)的时间复杂度内计算A的所有系数。

点值加法: 如果C(x)=A(x)+B(x),则对于任意点x_k,满足C(x_k)=A(x_k)+B(x_k)。准确的说,如果A的点值表达是\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1})\}B的点值表达是\{(x_0,y_0^"),(x_1,y_1^"),\dots,(x_{n-1},y_{n-1}^")\},(注意,A和B在相同的n个位置求值),则C的点值求值是\{(x_0,y_0+y_0^"),(x_1,y_1+y_1^"),\dots,(x_{n-1},y_{n-1}+y_{n-1}^")\}

点值乘法: 如果C(x)=A(x)B(x),则对于任意点x_k,满足C(x_k)=A(x_k)B(x_k)。准确的说,如果A的点值表达是\{(x_0,y_0),(x_1,y_1),\dots,(x_{2n-1},y_{2n-1})\}B的点值表达是\{(x_0,y_0^"),(x_1,y_1^"),\dots,(x_{2n-1},y_{2n-1}^")\},(注意,A和B在相同的2n个位置求值),则C的点值求值是\{(x_0,y_0y_0^"),(x_1,y_1y_1^"),\dots,(x_{2n-1},y_{2n-1}y_{2n-1}^")\}

系数形式表示的多项式快速乘法

我们能否进行基于点值形式的线性时间乘法算法来加速基于系数形式表达的多项式乘法运算,关键在于能否快速进行多项式系数形式和点值形式的转换

前面说过,巧妙的选取点值,其运算时间就可以变成O(n log_2 n)。事实上我们我们选择单位复数根作为求值点就可以满足。我们对系数向量跑DFT,得到相应的点值表达。也可以对点值执行IDFT变换,得到系数向量。

基础FFT

我们先来看FFT的主过程

(这里设log_2 n是一个整数。如果log_2 n不是整数就扩大n直到n是整数为止)

A_1(x)=a_0+a_2*{x}+a_4*{x^2}+\dots+a_{n-2}*x^{\frac{n}{2}-1},A_2(x)=a_1*x+a_3*{x}+a_5*{x^2}+ \dots+a_{n-1}*x^{\frac{n}{2}-1}

A(x)=A_1(x^2)+xA_2(x^2)

代入w_n^k (k<\frac{n}{2})

A(w_n^k)=A_1(w_n^{2k})+w_n^kA_2(w_n^{2k}) =A_1(w_{\frac{n}{2}}^{k})+w_n^kA_2(w_{\frac{n}{2}}^{k})

代入w_n^{k+\frac{n}{2}}

A(w^{k+\frac{n}{2}})=A_1(w_n^{2k+n})+w_n^{k+\frac{n}{2}}(w_n^{2k+n}) =A_1(w_n^{2k}*w_n^n)-w_n^kA_2(w_n^{2k}*w_n^n) =A_1(w_n^{2k})-w_n^kA_2(w_n^{2k})

显然这两个式子只有常数不同,所以可以一并计算

又因为计算的过程是递归实现的,所以可以分治。

根据这种方式所写的FFT代码如下:

#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=1000010;
const double pi=acos(-1);
complex<double> a[N],b[N];
ll n,m;

inline ll read(){
    ll x=0,tmp=1;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') tmp=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<3)+(x<<1)+(ch^48);
        ch=getchar();
    }
    return tmp*x;
}

void FFT(complex<double> *a,ll n,ll op){
    if(!n) return;
    complex<double> a0[n],a1[n];
    for(ll i=0; i<n; i++){
        a0[i]=a[i<<1];
        a1[i]=a[i<<1|1];
    }
    FFT(a0,n>>1,op); FFT(a1,n>>1,op);
    complex<double> W(cos(pi/n),sin(pi/n)*op),w(1,0);
    for(ll i=0; i<n; i++,w*=W){
        a[i]=a0[i]+w*a1[i];
        a[i+n]=a0[i]-w*a1[i];
    }
}

int main(){
    n=read(); m=read();
    for(ll i=0; i<=n; i++) a[i]=read();
    for(ll i=0; i<=m; i++) b[i]=read();
    for(m+=n,n=1; n<=m; n<<=1);
    FFT(a,n>>1,1); FFT(b,n>>1,1);
    for(ll i=0; i<n; i++) a[i]*=b[i];
    FFT(a,n>>1,-1);
    for(ll i=0; i<=m; i++) printf("%.0lf ",fabs(a[i].real()/n));
    return 0;
}

这份代码能过loj的FFT板子(n \leq 10^5),但是在洛谷的板子(n \leq 10^6)上GG了

我们需要继续优化常数

迭代实现

我们发现我们需要求的序列是原序列下标的二进制反转。

因此我们可以用O(n)得到每一段序列的值,然后不断向上合并即可。

另外洛谷这题卡了STL的complex,complex要手写

代码(STL版)

#include<bits/stdc++.h>
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=10000010;
const double pi=acos(-1);
ll n,m,limit;
complex<double> a[N],b[N];
ll c[N];

inline ll read(){
    ll x=0,tmp=1;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') tmp=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<3)+(x<<1)+(ch^48);
        ch=getchar();
    }
    return tmp*x;
}

inline void write(ll x){
    if(x<0){
        putchar('-');
        x=-x;
    }
    ll y=10,len=1;
    while(y<=x){
        y=(y<<3)+(y<<1);
        len++;
    }
    while(len--){
        y/=10;
        putchar(x/y+48);
        x%=y;
    }
}

void FFT(complex<double> *a,ll op){
    for(ll i=0; i<limit; i++){
        if(i<c[i]) swap(a[i],a[c[i]]);
    }
    for(ll mid=1; mid<limit; mid<<=1){
        complex<double> W(cos(pi/mid),op*sin(pi/mid));
        for(ll r=mid<<1,j=0; j<limit; j+=r){
            complex<double> w(1,0);
            for(ll l=0; l<mid; l++,w*=W){
                complex<double> x=a[j+l],y=w*a[j+mid+l];
                a[j+l]=x+y; a[j+mid+l]=x-y;
            }
        }
    }
}

int main(){
    n=read(); m=read();
    for(ll i=0; i<=n; i++) a[i]=read();
    for(ll i=0; i<=m; i++) b[i]=read();
    limit=1; ll l=0;
    while(limit<=n+m){
        limit<<=1;
        l++;
    }
    for(ll i=0; i<limit; i++) c[i]=(c[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1); FFT(b,1);
    for(ll i=0; i<=limit; i++) a[i]*=b[i];
    FFT(a,-1);
    for(ll i=0; i<=n+m; i++){
        write(a[i].real()/limit+0.5);
        putchar(' ');
    }
    return 0;
}

代码(手写STL版)

#include<iostream>
#include<cstdio>
#include<cmath>
#define ll long long
using namespace std;

const ll N=10000010;
const double pi=acos(-1);
ll n,m,limit,c[N];
struct complex{
    double real,imag;
    complex(double X=0,double Y=0){real=X; imag=Y;}
}a[N],b[N];
inline complex operator +(complex a,complex b){return complex(a.real+b.real,a.imag+b.imag);}
inline complex operator -(complex a,complex b){return complex(a.real-b.real,a.imag-b.imag);}
inline complex operator *(complex a,complex b){return complex(a.real*b.real-a.imag*b.imag,a.real*b.imag+a.imag*b.real);}

inline ll read(){
    ll x=0,tmp=1;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') tmp=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<3)+(x<<1)+(ch^48);
        ch=getchar();
    }
    return tmp*x;
}

inline void write(ll x){
    if(x<0){
        putchar('-');
        x=-x;
    }
    ll y=10,len=1;
    while(y<=x){
        y=(y<<3)+(y<<1);
        len++;
    }
    while(len--){
        y/=10;
        putchar(x/y+48);
        x%=y;
    }
}

void FFT(complex *a,ll op){
    for(ll i=0; i<limit; i++){
        if(i<c[i]) swap(a[i],a[c[i]]);
    }
    for(ll mid=1; mid<limit; mid<<=1){
        complex W(cos(pi/mid),op*sin(pi/mid));
        for(ll r=mid<<1,j=0; j<limit; j+=r){
            complex w(1,0);
            for(ll l=0; l<mid; l++,w=w*W){
                complex x=a[j+l],y=w*a[j+mid+l];
                a[j+l]=x+y; a[j+mid+l]=x-y;
            }
        }
    }
}

int main(){
    n=read(); m=read();
    for(ll i=0; i<=n; i++) a[i].real=read();
    for(ll i=0; i<=m; i++) b[i].real=read();
    limit=1; ll l=0;
    while(limit<=n+m){
        limit<<=1;
        l++;
    }
    for(ll i=0; i<limit; i++) c[i]=(c[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1); FFT(b,1);
    for(ll i=0; i<=limit; i++) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(ll i=0; i<=n+m; i++){
        write(a[i].real/limit+0.5);
        putchar(' ');
    }
    return 0;
}