FFT学习笔记
万弘
2020-01-20 10:14:47
## FFT学习笔记
FFT,即快速傅里叶变换,用于在$O(nlogn)$的时间内计算两个多项式的卷积。显然暴力计算是$O(n^2)$的。
前置技能:复数基本运算(至少你要知道,复数可以用复平面上的向量表示),多项式基础。
众所周知,多项式有两种表示方法:
1. 系数表示,即$f(x)=\sum_{i=0}^na_ix^i$,这也是最常用的方法
2. 点值表示,将$n$互不相同的$x$带入系数表示,会得到$n$个相应的$y$,因此得到$n$个点$(x_1,y_1),(x_2,y_2)...(x_n,y_n)$.与三点定一抛物线类似,这$n$个点可以唯一确定一个$n-1$次多项式.
因为$(f\times g)(x)=f(x)\times g(x)$,所以若已知两个多项式的点值表示,我们可以直接$O(n)$乘起来,得到其卷积的点值表示。
那么回到原问题,我们已知两个系数表示的多项式,如何在低于暴力的时间复杂度内求出其卷积的系数表示?
科学家们给出了方法:将给定的两个系数表示转化为两个点值表示(DFT,离散傅里叶变换),用上面的方法直接乘,得到卷积的点值表示,再转化为卷积的系数表示(IDFT,逆离散傅里叶变换)
## DFT
给一个$n$次(默认$n$是2的次幂,不足高位补0即可)多项式,如何求出其点值表示?
$$f(x)=\sum_{i=0}^{n-1}a_ix^i$$
按指数的奇偶性分类,
$$f(x)=(a_0+a_2x^2+a_4x^4+..+a_{n-2}x^{n-2})+(a_1x+a_3x^3+..+a_{n-1}x^{n-1}$$
记$A1(x)=a_0+a_2x+a_4x^2+..a_{n-2}x^{n/2-1},A2(x)=a_1+a_3x+..+a_{n-1}x^{n/2-1}$
则有
$$f(x)=A1(x^2)+xA2(x^2)$$
为什么要这样做?可以发现,$A1,A2$的次数都是$n/2-1$次,我们缩小了问题的规模,递归地求出他们的点值式,并拼起来即可得到$f(x)$。还有两个问题:如何取点?如何拼?
引入**单位根**
引一段attack大佬的话:![](https://cdn.luogu.com.cn/upload/image_hosting/ayevrzgq.png)
单位根的有用性质:
1:$w_n^k=w_{2n}^{2k}$,因为$w_{2n}^{2k}=\cos 2k\times \frac{2\pi}{2n}+i\sin 2k\times \frac{2\pi}{2n}=w_n^k$
2:$w_0^0=w_n^n=1$
3:$w_n^{n/2}=-1$.你看上面那个圆,分成上部和下部,一半就是-1啊。
好我们回到原题。
$$f(x)=A1(x^2)+xA2(x^2)$$
令$x=w_n^k(k<n/2),f(x)=A1(w_n^{2k})+w_n^kA2(w_n^{2k})$
令$x=w_n^{k+n/2},f(x)=A1(w_n^{2k+n})+w_n^{k+n/2}A2(w_n^{2k+n})=A1(w_n^{2k})-w_n^kA2(w_n^{2k})$
可以发现,两式仅有一个常数不同,因此知一式即可$O(1)$求二式
由于$A1,A2$是$n/2-1$次的,将$k\in [0,n/2-1]$带入,即可$O(n)$完成“拼”的过程。
所以DFT的时间复杂度$T(n)=2T(n/2)+O(n)=O(nlogn)$
求得点值表示后,我们直接$O(n)$把$f(x),g(x)$的点值表示乘起来即可得到其卷积的点值表示
## IDFT
给定一个点值表示,求其系数表示。
记$f(x)=\sum_{i=0}^{n-1}a_ix^i$在$(w_n^0,w_n^1..w_n^n)$处点值表示为$(y_0,y_1,y_2..y_{n-1})$
设向量$(c_0,c_1,c_2...c_{n-1})$满足$c_k=\sum_{i=0}^{n-1}y_i(w_n^{-k})^i$
推一波式子:
$$c_k=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_jw_n^{ij}(w_n^{-k})^i=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_jw_n^{ij}w_n^{-ik}$$
$$=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(w_n^{j-k})^i$$
设$g(x)=\sum_{i=0}^{n-1}x^i$。令$x=w_n^k$,则
$$g(w_n^k)=1+w_n^k+w_n^2k+..w_n^{(n-1)k}$$
当公比$w_n^k\ne1,$等比数列求和可得
$$g(w_n^k)=\frac{w_n^kw_n^{(n-1)k}-1}{w_n^k-1}=\frac{1-1}{w_n^k-1}=0$$
若$w_n^k=1,k=0$或$k=n$。所以$g(x)=\sum_{i=0}^{n-1}x^i=n$
带入
$$c_k=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(w_n^{j-k})^i$$
$$=na_k$$
$$\therefore a_k=\frac{c_k}{n}$$
求出$c_k$即可。
小结:FFT,就是讲两个系数表示转化为点值表示,直接乘,再转化为系数表示。虽然绕了一大圈,但时间复杂度却是$\Theta (nlogn)$
## 代码实现
听起来这么nb,该如何实现呢?
由上面的思路,写出的递归版代码是这样的:
```cpp
complex a[MAXN],b[MAXN];//complex是自定义的复数类,.a为实部,.b为虚部
void FFT(complex* a,ll len,ll type)//正在处理的式,长度,type=1表示DFT,type=-1表示IDFT(因为他们只有一个系数不同)
{
if(len==1)return;
len>>=1;//现在的len是划分成区间的长度
complex a0[len],a1[len];//左区间,右区间
for(ll k=0;k<len;++k)
a0[k]=a[k<<1],a1[k]=a[k<<1|1];
FFT(a0,len,type);FFT(a1,len,type);//分治
complex wn(cos(M_PI/len),sin(M_PI/len)*type),w(1,0);//wn1,wn0
for(ll k=0;k<len;++k,w=w*wn)//wnk
{
complex tmp=w*a1[k];
a[k]=a0[k]+tmp,a[k+len]=a0[k]-tmp;//统计
}
}
```
嗯思路清晰!结果,Fast Fast TLE!
唔..常数太大了
考虑用迭代实现FFT.
![](https://cdn.luogu.com.cn/upload/image_hosting/lx18kv5c.png)
可以发现,每个位置的变化就是将二进制反转了。那显然我们就可以$\Theta(nlogn)$预处理出位置的变化(还可以线性递推,详见代码),然后,将信息自下而上合并,即可迭代实现FFT,常数更小。
最后是迭代实现的FFT完整代码,时间复杂度$\Theta(nlogn)$
```cpp
/**********/省略快读
#define MAXN 4000011
const double PI=M_PI;
struct complex//复数类
{
double a,b;
complex(double aa=0,double bb=0)
{
a=aa,b=bb;
}
complex operator +(const complex& t)
{
return complex(a+t.a,b+t.b);
}
complex operator -(const complex& t)
{
return complex(a-t.a,b-t.b);
}
complex operator *(const complex& t)
{
return complex(a*t.a-b*t.b,a*t.b+b*t.a);
}
};
complex a[MAXN],b[MAXN];
double cosp[MAXN],sinp[MAXN];//预处理三角函数,sinp[i]=sin(PI/i)
ll status[MAXN];//最终状态
void FFT(complex* a,ll len,ll type)//迭代fft.原式,长度,type=1表示DFT,type=-1表示IDFT
{
if(len==1)return;
for(ll i=0;i<len;++i)
if(status[i]>i)std::swap(a[i],a[status[i]]);//交换
for(ll cur=1;cur<len;cur<<=1)//现在的长度
{
complex wn(cosp[cur],sinp[cur]*type);//wn1
for(ll i=cur<<1,j=0;j<len;j+=i)//j是当前位置,i是长度的两倍
{
complex w(1,0);//wn0
for(ll k=0;k<cur;++k,w=w*wn)
{
complex x=a[j+k],y=w*a[j+cur+k];//推出现在的式
a[j+k]=x+y;
a[j+k+cur]=x-y;
}
}
}
}
int main()
{
//freopen("P3803_5.in","r",stdin);
ll n=read(),m=read(),cur=1,dep=0;
for(ll i=0;i<=n;++i)a[i].a=double(read());
for(ll i=0;i<=m;++i)b[i].a=double(read());
while(cur<=n+m)cur<<=1,++dep;//变成2的次幂
for(ll i=0;i<cur;++i)//线性递推位置,用nlogn的朴素方法也可
status[i]=(status[i>>1]>>1)|((i&1)<<(dep-1));
for(ll j=1;j<cur;j<<=1)
{
cosp[j]=cos(PI/j);
sinp[j]=sin(PI/j);
}
FFT(a,cur,1);//A变为点值表示
FFT(b,cur,1);//B变为点值表示
for(ll i=0;i<cur;++i)a[i]=a[i]*b[i];//点值直接乘
FFT(a,cur,-1);//点值表示转为系数表示
for(ll i=0;i<=n+m;++i)
printf("%d ",int(a[i].a/cur+0.5));//四舍五入
return 0;
}
```