FFT学习笔记

万弘

2020-01-20 10:14:47

Personal

## FFT学习笔记 FFT,即快速傅里叶变换,用于在$O(nlogn)$的时间内计算两个多项式的卷积。显然暴力计算是$O(n^2)$的。 前置技能:复数基本运算(至少你要知道,复数可以用复平面上的向量表示),多项式基础。 众所周知,多项式有两种表示方法: 1. 系数表示,即$f(x)=\sum_{i=0}^na_ix^i$,这也是最常用的方法 2. 点值表示,将$n$互不相同的$x$带入系数表示,会得到$n$个相应的$y$,因此得到$n$个点$(x_1,y_1),(x_2,y_2)...(x_n,y_n)$.与三点定一抛物线类似,这$n$个点可以唯一确定一个$n-1$次多项式. 因为$(f\times g)(x)=f(x)\times g(x)$,所以若已知两个多项式的点值表示,我们可以直接$O(n)$乘起来,得到其卷积的点值表示。 那么回到原问题,我们已知两个系数表示的多项式,如何在低于暴力的时间复杂度内求出其卷积的系数表示? 科学家们给出了方法:将给定的两个系数表示转化为两个点值表示(DFT,离散傅里叶变换),用上面的方法直接乘,得到卷积的点值表示,再转化为卷积的系数表示(IDFT,逆离散傅里叶变换) ## DFT 给一个$n$次(默认$n$是2的次幂,不足高位补0即可)多项式,如何求出其点值表示? $$f(x)=\sum_{i=0}^{n-1}a_ix^i$$ 按指数的奇偶性分类, $$f(x)=(a_0+a_2x^2+a_4x^4+..+a_{n-2}x^{n-2})+(a_1x+a_3x^3+..+a_{n-1}x^{n-1}$$ 记$A1(x)=a_0+a_2x+a_4x^2+..a_{n-2}x^{n/2-1},A2(x)=a_1+a_3x+..+a_{n-1}x^{n/2-1}$ 则有 $$f(x)=A1(x^2)+xA2(x^2)$$ 为什么要这样做?可以发现,$A1,A2$的次数都是$n/2-1$次,我们缩小了问题的规模,递归地求出他们的点值式,并拼起来即可得到$f(x)$。还有两个问题:如何取点?如何拼? 引入**单位根** 引一段attack大佬的话:![](https://cdn.luogu.com.cn/upload/image_hosting/ayevrzgq.png) 单位根的有用性质: 1:$w_n^k=w_{2n}^{2k}$,因为$w_{2n}^{2k}=\cos 2k\times \frac{2\pi}{2n}+i\sin 2k\times \frac{2\pi}{2n}=w_n^k$ 2:$w_0^0=w_n^n=1$ 3:$w_n^{n/2}=-1$.你看上面那个圆,分成上部和下部,一半就是-1啊。 好我们回到原题。 $$f(x)=A1(x^2)+xA2(x^2)$$ 令$x=w_n^k(k<n/2),f(x)=A1(w_n^{2k})+w_n^kA2(w_n^{2k})$ 令$x=w_n^{k+n/2},f(x)=A1(w_n^{2k+n})+w_n^{k+n/2}A2(w_n^{2k+n})=A1(w_n^{2k})-w_n^kA2(w_n^{2k})$ 可以发现,两式仅有一个常数不同,因此知一式即可$O(1)$求二式 由于$A1,A2$是$n/2-1$次的,将$k\in [0,n/2-1]$带入,即可$O(n)$完成“拼”的过程。 所以DFT的时间复杂度$T(n)=2T(n/2)+O(n)=O(nlogn)$ 求得点值表示后,我们直接$O(n)$把$f(x),g(x)$的点值表示乘起来即可得到其卷积的点值表示 ## IDFT 给定一个点值表示,求其系数表示。 记$f(x)=\sum_{i=0}^{n-1}a_ix^i$在$(w_n^0,w_n^1..w_n^n)$处点值表示为$(y_0,y_1,y_2..y_{n-1})$ 设向量$(c_0,c_1,c_2...c_{n-1})$满足$c_k=\sum_{i=0}^{n-1}y_i(w_n^{-k})^i$ 推一波式子: $$c_k=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_jw_n^{ij}(w_n^{-k})^i=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_jw_n^{ij}w_n^{-ik}$$ $$=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(w_n^{j-k})^i$$ 设$g(x)=\sum_{i=0}^{n-1}x^i$。令$x=w_n^k$,则 $$g(w_n^k)=1+w_n^k+w_n^2k+..w_n^{(n-1)k}$$ 当公比$w_n^k\ne1,$等比数列求和可得 $$g(w_n^k)=\frac{w_n^kw_n^{(n-1)k}-1}{w_n^k-1}=\frac{1-1}{w_n^k-1}=0$$ 若$w_n^k=1,k=0$或$k=n$。所以$g(x)=\sum_{i=0}^{n-1}x^i=n$ 带入 $$c_k=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(w_n^{j-k})^i$$ $$=na_k$$ $$\therefore a_k=\frac{c_k}{n}$$ 求出$c_k$即可。 小结:FFT,就是讲两个系数表示转化为点值表示,直接乘,再转化为系数表示。虽然绕了一大圈,但时间复杂度却是$\Theta (nlogn)$ ## 代码实现 听起来这么nb,该如何实现呢? 由上面的思路,写出的递归版代码是这样的: ```cpp complex a[MAXN],b[MAXN];//complex是自定义的复数类,.a为实部,.b为虚部 void FFT(complex* a,ll len,ll type)//正在处理的式,长度,type=1表示DFT,type=-1表示IDFT(因为他们只有一个系数不同) { if(len==1)return; len>>=1;//现在的len是划分成区间的长度 complex a0[len],a1[len];//左区间,右区间 for(ll k=0;k<len;++k) a0[k]=a[k<<1],a1[k]=a[k<<1|1]; FFT(a0,len,type);FFT(a1,len,type);//分治 complex wn(cos(M_PI/len),sin(M_PI/len)*type),w(1,0);//wn1,wn0 for(ll k=0;k<len;++k,w=w*wn)//wnk { complex tmp=w*a1[k]; a[k]=a0[k]+tmp,a[k+len]=a0[k]-tmp;//统计 } } ``` 嗯思路清晰!结果,Fast Fast TLE! 唔..常数太大了 考虑用迭代实现FFT. ![](https://cdn.luogu.com.cn/upload/image_hosting/lx18kv5c.png) 可以发现,每个位置的变化就是将二进制反转了。那显然我们就可以$\Theta(nlogn)$预处理出位置的变化(还可以线性递推,详见代码),然后,将信息自下而上合并,即可迭代实现FFT,常数更小。 最后是迭代实现的FFT完整代码,时间复杂度$\Theta(nlogn)$ ```cpp /**********/省略快读 #define MAXN 4000011 const double PI=M_PI; struct complex//复数类 { double a,b; complex(double aa=0,double bb=0) { a=aa,b=bb; } complex operator +(const complex& t) { return complex(a+t.a,b+t.b); } complex operator -(const complex& t) { return complex(a-t.a,b-t.b); } complex operator *(const complex& t) { return complex(a*t.a-b*t.b,a*t.b+b*t.a); } }; complex a[MAXN],b[MAXN]; double cosp[MAXN],sinp[MAXN];//预处理三角函数,sinp[i]=sin(PI/i) ll status[MAXN];//最终状态 void FFT(complex* a,ll len,ll type)//迭代fft.原式,长度,type=1表示DFT,type=-1表示IDFT { if(len==1)return; for(ll i=0;i<len;++i) if(status[i]>i)std::swap(a[i],a[status[i]]);//交换 for(ll cur=1;cur<len;cur<<=1)//现在的长度 { complex wn(cosp[cur],sinp[cur]*type);//wn1 for(ll i=cur<<1,j=0;j<len;j+=i)//j是当前位置,i是长度的两倍 { complex w(1,0);//wn0 for(ll k=0;k<cur;++k,w=w*wn) { complex x=a[j+k],y=w*a[j+cur+k];//推出现在的式 a[j+k]=x+y; a[j+k+cur]=x-y; } } } } int main() { //freopen("P3803_5.in","r",stdin); ll n=read(),m=read(),cur=1,dep=0; for(ll i=0;i<=n;++i)a[i].a=double(read()); for(ll i=0;i<=m;++i)b[i].a=double(read()); while(cur<=n+m)cur<<=1,++dep;//变成2的次幂 for(ll i=0;i<cur;++i)//线性递推位置,用nlogn的朴素方法也可 status[i]=(status[i>>1]>>1)|((i&1)<<(dep-1)); for(ll j=1;j<cur;j<<=1) { cosp[j]=cos(PI/j); sinp[j]=sin(PI/j); } FFT(a,cur,1);//A变为点值表示 FFT(b,cur,1);//B变为点值表示 for(ll i=0;i<cur;++i)a[i]=a[i]*b[i];//点值直接乘 FFT(a,cur,-1);//点值表示转为系数表示 for(ll i=0;i<=n+m;++i) printf("%d ",int(a[i].a/cur+0.5));//四舍五入 return 0; } ```