你能在不会虚数的情况下通关 FFT 吗?

· · 算法·理论

你能在不会虚数的情况下通关 FFT 吗?

0.前言

本文主要是为了萌新 OIer 快速入门 FFT 而准备的,所以大部分证明将会跳过,想看证明请去别的文章。

众所周知 OI 不考证明。如果你觉得能在考场写出 FFT 并正确使用就够了,并不想了解详细证明,那么这篇文章可能适合你。

前置知识:多项式与生成函数,插值。

(如果你不会生成函数就没必要学 FFT 了)

1.应用

FFT 主要是用来解决多项式乘法

2.原理

对于一个 n 次多项式,如果你知道 n+1 个点值,就能确定这个多项式。

一个 n 次多项式 F(x) 和一个 m 次多项式 G(x) 相乘的结果是一个 n+m 次多项式 H(x),如果我们知道 n+m+1 个不同的点值 (x_i,F(x_i))(x_i,G(x_i)),那么我们就能求出 (x_i,H(x_i))=(x_i,F(x_i)\times G(x_i)),插值得到 H(x)。乘法这一步是 O(n) 的,问题转化成了如何快速点值和插值。

3.实现

为了方便后面的分治操作,我们可以在 F(x),G(x),H(x) 后面加若干系数为 0 的项,使得他们都变成 N=2^k 的多项式。(N-1 次多项式)

为了快速点值和插值,如何选择 x_i 是关键。

定义一个虚数 \omega_n,使得 \omega_n^n=1。你可以理解为 \omega_n=1^{\frac{1}{n}}(看起来很奇怪但它都是虚数了就顺从它吧)

性质:\omega_{2n}^{2k}=\omega_n^k。你可以理解为 \omega_{2n}^{2k}=1^{\frac{2k}{2n}}=1^{\frac{k}{n}}=\omega_n^k

值:\omega_{n}=\cos(\frac{2\pi}{n})+\sin(\frac{2\pi}{n})i(直接记就行)

我们的目标是求出 F(x)\omega_N^0,\omega_N^1,\dots,\omega_N^{N-1} 上的值。G(x) 同理。

设:

F(x)=a_0+a_1x_1+a_2x^2+\dots+a_{N-1}x^{N-1}

定义:

F^{[0]}(x)=a_0+a_2x^1+a_4x^2+\dots+a_{N-2}x^{\frac{N-2}{2}} F^{[1]}(x)=a_1+a_3x^1+a_5x^2+\dots+a_{N-1}x^{\frac{N-2}{2}} \therefore F(x)=F^{[0]}(x^2)+xF^{[1]}(x^2)

问题转化为求 F^{[0]}(x)\omega_{\frac{N}{2}}^0,\omega_{\frac{N}{2}}^1,\dots,\omega_{\frac{N}{2}}^{N-1} 的值。F^{[1]} 同理。

根据定义,我们发现 \omega_{\frac{N}{2}}^0=\omega_{\frac{N}{2}}^{\frac{N}{2}},\omega_{\frac{N}{2}}^1=\omega_{\frac{N}{2}}^{\frac{N}{2}+1},\dots,\omega_{\frac{N}{2}}^{\frac{N}{2}-1}=\omega_{\frac{N}{2}}^{N-1},所以只要求出 F^{[0]}(x)\omega_{\frac{N}{2}}^0,\omega_{\frac{N}{2}}^1,\dots,\omega_{\frac{N}{2}}^{\frac{N}{2}-1} 的值,就能求出 F^{[0]}(x)\omega_{\frac{N}{2}}^{\frac{N}{2}},\omega_{\frac{N}{2}}^{\frac{N}{2}+1},\dots,\omega_{\frac{N}{2}}^{N-1} 的值。

所以我们可以通过不断将 a 分组,然后就可以递归求解了。

你问我怎么插值?把 \omega_n 换成 \cos(\frac{2\pi}{n})-\sin(\frac{2\pi}{n})i,然后把结果除以 N 就行了。(证明跳过)

递归代码

void FFT(int N,complex *a){
    if(limit==1)return;
    complex a1[N>>1],a2[N>>1];
    for(int i=0;2*i<=N;i++)a1[i]=a[i<<1],a2[i]=a[i<<1|1];
    FFT(N>>1,a1);
    FFT(N>>1,a2);
    complex wn=complex(cos(2.0*Pi/N),sin(2.0*Pi/N)),w=complex(1,0);
    for(int i=0;i<(N>>1);i++,w*=wn)a[i]=a1[i]+w*a2[i],a[i+(N>>1)]=a1[i]-w*a2[i];
}

但是这种方法需要动态开区间,导致常数不可避免的大。

4.优化

我们观察 a 经过分组后会变成什么样:

(图片来自 FlashHu 大佬的博客)

观察二进制:

000 001 010 011 100 101 110 111

变为:

000 100 010 110 001 101 011 111

我们发现就是上面的二进制翻转后的结果。

所以我们可以 O(n) 求出翻转后的结果 r

for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(k-1));

接下来只需从下到上不断合并即可。具体实现可以看看代码。

5.完整代码

#include<bits/stdc++.h>
#define C complex<double>
using namespace std;
const int N=4e6+10;
const double PI=acos(-1);
int n,m,r[N],l;
C f[N],g[N];
void FFT(C *a,int op){
    for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int i=1;i<n;i<<=1){
        C wn{cos(PI/i),sin(PI/i)*op};
        for(int j=0;j<n;j+=(i<<1)){
            C w{1,0};
            for(int k=0;k<i;k++,w*=wn){
                C x=a[j+k],y=a[j+k+i];
                a[j+k]=x+w*y;
                a[j+k+i]=x-w*y;
            }
        }
    }
}
int main(){
    cin>>n>>m;
    for(int i=0,x;i<=n;i++)cin>>x,f[i]=x;
    for(int i=0,x;i<=m;i++)cin>>x,g[i]=x;
    m+=n;n=1;l=-1;
    while(n<=m)n<<=1,l++;
    for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)<<l);
    FFT(f,1);FFT(g,1);
    for(int i=0;i<n;i++)f[i]*=g[i];
    FFT(f,-1);
    for(int i=0;i<=m;i++)printf("%.0lf ",fabs(f[i].real()/n));
    return 0;
}