详解 FFT

· · 个人记录

前置知识:多项式,分治。

应用场景:多项式乘法。

温馨提示:本文证明不必掌握,仅供想要了解的人阅读。

〇、导入

您一定算过多项式乘法吧!有的时候,这算起来比较麻烦,比如:

\begin{aligned} &(x^2+2x-2)(2x^2-x+3)\\ =~&x^2(2x^2-x+3)+2x(2x^2-x+3)-2(2x^2-x+3)\\ =~&(2x^4-x^3+3x^2)+(4x^3-2x^2+6x)-(4x^2-2x+6)\\ =~&2x^4-x^3+3x^2+4x^3-2x^2+6x-4x^2+2x-6\\ =~&2x^4+(-x^3+4x^3)+(3x^2-2x^2-4x^2)+(6x+2x)-6\\ =~&2x^4+3x^3-3x^2+8x-6. \end{aligned}

\LaTeX 表示就更麻烦了。

在实际应用上,有时面对的多项式甚至多达上万项!这个时候再人工手算效率过低,且容易出错。幸好,我们已经有了计算机,能够用非常快的速度算出结果!

暴力算法是很容易想到的:

for(int i=0;i<n;++i)
    for(int j=0;j<m;++j)
        c[i+j]+=a[i]*b[j];

但它的时间复杂度为 \Theta(n^2) 级,如果遇到上万的数据就容易被卡了。

有没有更快的方法呢?当然有!那就是我们现在要讲的 FFT !

一、知识补充

1. 多项式

1-1 多项式的一般表达

我们通常用 F(x) 来表示一个多项式,定义一个多项式只需用 F(x)=a_nx^n+a_{n-1}x^{n-1}+\cdots+a_0 ,如 F(x)=x^2+3x-5 。可以把它理解为函数,比如 F(2) 就是将 x=2 代入多项式 F(x) 后的值。

1-2 多项式的点值表达

我们知道,在平面直角坐标系中,n+1 个不重合的点可以唯一确定一个一元 n 次多项式。

所以我们可以n+1 个点值来表示一个一元 n 次多项式

那么如何通过点值表达来计算多项式乘法呢?

设已知两个一元 n 次多项式 F(x),G(x) 的点值表达,W(x)=F(x)\times G(x) 。很显然,多项式 W(x)x=i 时的点值为 W(i)=F(i)\times G(i)

但是,就像前面所说的那样,n+1 个不重合的点可以唯一确定一个一元 n 次多项式,而 W(x) 显然是个一元 2n 次多项式,n+1 个点值是根本不够的。怎么办?

很简单,只需先分别算出 F(i),G(i)2n+1 个点值就行了!

1-3 多项式乘法的本质——卷积

接下来我们来思考一下乘法的本质。

容易想到,积的每一项的系数都可以表示为:

c_k=\sum\limits_{i=0}^ka_ib_{k-i}.

对于

c_k=\sum\limits_{f(i,j)=k}a_ib_{j}(f(i,j)\text{ 为某种运算})

这样的式子,我们称之为卷积。(了解更多)

可以看出,多项式乘法其实就是加法卷积。

2. 复数

2-1 定义

学复数之前,我们所接触的数仅限于实数范围内。

从自然数到整数、有理数再到实数,数的天地越来越大。现在,我们将数系扩充到复数范围内:

复数

我们用 \mathrm i 来表示虚数单位 \sqrt{-1} ,复数用 a+b\mathrm i(a,b\in\mathbf R) 表示,其中 a 为实部,b 为虚部。

我们定义复数 a+b\mathrm ia-b\mathrm i 互为共轭复数。(通常记复数 z 的共轭复数为 \overline z

复数 z=a+b\mathrm i 可以分类如下:

\text{复数}~z \left\{ \begin{aligned} \text{实数}&(b=0),\\ \text{虚数}&(b\not=0)(\text{当}~a=0~\text{时为纯虚数}). \end{aligned} \right.

2-2 几何意义

当我们还在学实数的时候,数轴是这样的:

现在又多了复数。于是,我们需要用一个平面直角坐标系来表示数轴了:

这个建立了直角坐标系来表示复数的平面叫做复平面x 轴叫做实轴y 轴叫做虚轴。显然,在实轴上的点都表示实数;除了原点外,虚轴上的点都表示纯虚数。复平面内的点 (a,b) 表示虚数 a+b\mathrm i

由上图可以看出,复数 z=a+bi 、点 Z(a,b) 和 平面向量 \overrightarrow{OZ} 一一对应。我们将向量 \overrightarrow{OZ} 的模 r 叫做复数 z 的模,记作 \lvert z\rvert 。由模的定义可知:

\lvert z\rvert=\lvert a+b\mathrm i\rvert=r =\sqrt{a^2+b^2}\quad(r\ge 0,r\in\mathbf R).

为方便起见,我们常把复数 z=a+b\mathrm i 说成点 Z 或说成向量 \overrightarrow{OZ} ,并且规定,相等的向量表示同一个复数。

辐角

z\not=0 时,向量 \overrightarrow{OZ} 与实轴正向的夹角称为复数 z辐角,记作 \operatorname{Arg}z 。辐角的符号规定为:由正实轴依反时针方向转到 \overrightarrow{OZ} 为正,依顺时针方向转到 \overrightarrow{OZ} 为负。

显然一个非零复数 z 的辐角有无穷多个值,它们相差 \pi 的整数倍,但 \operatorname{Arg}z 中只有一个值 \theta_0 满足条件 -\pi\le\theta_0\le\pi ,称为复数的主辐角,记为 \arg z ,于是

\operatorname{Arg}z=\arg z+2n\pi.

这里我们使用的是弧度制,单位为 \mathrm{rad},可以省略不写。\pi~\mathrm{rad}=180\degree.

2-3 四则运算

2-3-1 公式

\begin{aligned} (a+b\mathrm i)+(c+d\mathrm i) &=(a+c)+(b+d)\mathrm i,\\ \\(a+b\mathrm i)-(c+d\mathrm i) &=(a-c)+(b-d)\mathrm i,\\ \\(a+b\mathrm i)\times(c+d\mathrm i) &=ac+bd\mathrm i^2+ad\mathrm i+bc\mathrm i\\ &=(ac-bd)+(ad+bc)\mathrm i,\\ \\(a+b\mathrm i)\div (c+d\mathrm i) &=\frac{(a+b\mathrm i)\times(c-d\mathrm i)}{(c+d\mathrm i)\times(c-d\mathrm i)}\\ &=\frac{(ac+bd)+(bc-ad)\mathrm i}{c^2-d^2\mathrm i^2}\\ &=\frac{ac+bd}{c^2+d^2}+\frac{bc-ad}{c^2+d^2}\mathrm i. \end{aligned}

2-3-2 几何意义

复数的加减法与向量的加减法对应,很容易理解。

复数相乘时,模长相乘,辐角相加。

证明:

设有两个复数 z_1=a+b\mathrm i,z_2=c+d\mathrm i 分别对应点 A(a,b),B(c,d) ,它们的积为 z_1z_2=(ac-bd)+(ad+bc)\mathrm i ,对应点 C(ac-bd,ad+bc)

给出一个图方便大家验证:

3. 单位根

3-1 定义

我们称 n 次幂为 1 的复数为 n 次单位根,即方程

x^n=1

的复数解。

因为 n 次方程有且只有 n 个复数根,所以 n 次单位根一共有 n 个。

为了求出单位根,我们在复平面上放一个单位圆。单位圆即圆心在原点上且半径为 1 的圆。如图:

显然,单位圆是由所有模长为 1 的复数表示的点构成的。

关于一个复数 zz^n 的关系:

所以,只有模长为 1 的复数可能为 n 次单位根,即所有单位根表示的点都在单位圆上。可以发现,n 次单位根是且只能是辐角为 \frac{2k\pi}n(0\le k<n,k\in\mathbf Z)(即 \frac 1 n 圆周)的复数。

所以这 nn 次单位根 n 等分单位圆。

3-2 性质

  1. 首先已知 \arg z\omega_n^1\pi\equiv(\arg z+\frac{2\pi}n)\ (\bmod\ 2\pi)\quad(z\not=0) ,即一个非零复数每乘一次 \omega_n^1 都相当于辐角增加 \frac 1 n 圆周。
  2. 证明: $$ \begin{aligned} \because~&\arg \omega_n^k=\arg \omega_n^{k\bmod n}\\ \therefore~&\omega_n^k=\omega_n^{k\bmod n}. \end{aligned} $$
  3. 证明:由性质 $0$ 可知 $\arg(\omega_n^1)^k=k\arg\omega_n^1=\frac{2k\pi}n=\arg\omega_n^k$ ,即 $(\omega_n^1)^k$ 相当于辐角为 $\frac k n$ 圆周的复数,即 $\omega_n^k$ 。
  4. 证明:由性质 $0,1$ 可得。
  5. 证明:由性质 $2$ 可得。
  6. 证明:因为 $\arg\omega_n^k=\frac{2k\pi}{n},\arg\omega_{pn}^{pk}=\frac{2pk\pi}{pn}=\frac{2k\pi}{n}$ ,即 $\omega_{pn}^{pk}$ 的辐角为 $\frac{pk}{pn}$ 圆周,$\omega_n^k$ 的辐角为 $\frac k n$ 圆周,所以 $\arg\omega_n^k=\arg\omega_{pn}^{pk}$ 。
  7. 证明:由性质 $3$ 可得 $\omega_n^{(k+n/2)}=\omega_n^k\times\omega_n^{(n/2)}=\omega_n^k\times-1=-\omega_n^k\quad(2|n)$ 。

3-3 单位根反演

在这里只是为了验证 FFT 的正确性,不想了解可以跳过。其中

[P]= \left\{ \begin{aligned} 1&, &P\\ 0&, &\neg P \end{aligned} \right.

是不是跟 C++ 的 bool 值一样?

补充完了前置知识 ,感觉身体被掏空,让我们开始学习 FFT 的核心思想吧!

二、核心

利用 FFT 计算多项式乘法大致分为两步:DFT(离散傅里叶变换) 和 IDFT(离散傅里叶逆变换) 。其中 DFT 表示将多项式的系数表达转换为点值表达,IDFT 则是它的逆运算,即将多项式的点值表达转换为系数表达。FFT 通过分治将 DFT 加速到 \Theta(n\log n) ,能过 10^6 级的数据。

1. 思想

1-1 DFT

设有一一元 n-1(n=2^t,t\in\mathbf N) 次(即有 n 项)多项式

F(x)=\sum\limits_{i=0}^{n-1}k_ix^i,

将它的每一项安装次数的奇偶分成两部分:

F(x) =\sum\limits_{i=0}^{n/2-1}k_{2i}x^{2i} +\sum\limits_{i=0}^{n/2-1}k_{2i+1}x^{2i+1}.

设有两个多项式

\begin{aligned} F_l(x)&=\sum\limits_{i=0}^{n/2-1}k_{2i}x^i,\\ F_r(x)&=\sum\limits_{i=0}^{n/2-1}k_{2i+1}x^i, \end{aligned}

那么

F(x)=F_l(x^2)+xF_r(x^2).

\omega_n^k(k<n/2) 代入 F(x) 中,得

\begin{aligned} F(\omega_n^k) &=F_l((\omega_n^k)^2)+\omega_n^kF_r((\omega_n^k)^2)\\ &=F_l(\omega_{n/2}^k)+\omega_n^kF_r(\omega_{n/2}^k). \end{aligned}

\omega_n^{k+n/2}(k<n/2) 代入 F(x) 中,得

\begin{aligned} F(\omega_n^{k+n/2}) &=F_l((\omega_n^{k+n/2})^2)+\omega_n^{k+n/2}F_r((\omega_n^{k+n/2})^2)\\ &=F_l(\omega_n^{2k+n})-\omega_n^kF_r(\omega_n^{2k+n})\\ &=F_l(\omega_n^{2k})-\omega_n^kF_r(\omega_n^{2k})\\ &=F_l(\omega_{n/2}^k)-\omega_n^kF_r(\omega_{n/2}^k). \end{aligned}

如果我们知道 F_l(x)F_r(x)\omega_{n/2}^0,\omega_{n/2}^1,\cdots,\omega_{n/2}^{n/2-1} 上的点值表达,那么可以用

F(\omega_n^k) =F_l(\omega_{n/2}^k)+\omega_n^kF_r(\omega_{n/2}^k) \quad(k<n/2)

来求 F(x)\omega_n^0,\omega_{n/2}^1,\cdots,\omega_n^{n/2-1} 上的点值表达,用

F(\omega_n^{k+n/2}) =F_l(\omega_{n/2}^k)-\omega_n^kF_r(\omega_{n/2}^k) \quad(k<n/2)

来求 F(x)\omega_n^{n/2},\omega_{n/2}^1,\cdots,\omega_n{n-1} 上的点值表达。

也就是说,如果我们知道 F_l(x)F_r(x)\omega_{n/2}^0,\omega_{n/2}^1,\cdots,\omega_{n/2}^{n/2-1} 上的点值表达,能够 \Theta(n) 求出 F(x)\omega_n^0,\omega_{n/2}^1,\cdots,\omega_n^{n-1} 上的点值表达。

但问题在于:我们不知道 F_l(x)F_r(x)\omega_{n/2}^0,\omega_{n/2}^1,\cdots,\omega_{n/2}^{n/2-1} 上的点值表达啊?

这时候就要用到分治了!

可以想到,用同样的方式处理 F_l(x)F_r(x),直到处理的多项式为常数(即只有一项)时就不需要再操作了。

这种通过分治加速 DFT 的方法即为 FFT 。

1-2 IDFT

设多项式 F(x)=\sum\limits_{i=0}^{n-1}a_ix^i\omega_n^k 上的点值表达为 b_k ,则

b_k=F(\omega_n^k) =\sum\limits_{i=0}^{n-1}a_i(\omega_n^k)^i

只需记住

a_k=\frac 1 n\sum\limits_{i=0}^{n-1}b_i(\omega_n^{-k})^i

就可以了!

所以,IDFT 相当于 DFT 用 \omega_n^{-k} 代替 \omega_n^k 后再除以 n

证明:(需要用到单位根反演)

\begin{aligned} a_k &=\frac 1 n\sum\limits_{i=0}^{n-1}b_i(\omega_n^{-k})^i\\ &=\frac 1 n\sum\limits_{i=0}^{n-1} \left[ \sum\limits_{j=0}^{n-1}a_j(\omega_n^i)^j \right] (\omega_n^{-k})^i\\ &=\frac 1 n\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{n-1}a_j\omega_n^{ij-ik}\\ &=\frac 1 n\sum\limits_{j=0}^{n-1}a_j\sum\limits_{i=0}^{n-1}\omega_n^{i(j-k)}\\ &=\frac 1 n\sum\limits_{j=0}^{n-1}a_jn[j=k]\\ &=a_k. \end{aligned}

2. 初步实现

2-1 多项式

我们需要一个复数数组 a[n] 来储存一个 n-1 次多项式。用 a[i] 储存一个多项式的 i 次系数或在 \omega_n^i 上的点值表达。

前面说过,n=2^t,t\in\mathbf N ,即 n2 的整数次幂。但在实际应用中,n 很可能不为 2 的整数次幂。怎么办?

很简单,如果它不为 2 的整数次幂,那就将它向上补到 2 的整数次幂,再将多出来的项的系数赋值为 0

2-2 复数

STL 自带复数库 complex ,但为防止被卡,最好自己手写一个复数结构体。在 FFT 中我们只需要用到加减乘即可。

struct Complex
{
    double a,b;
    Complex(){}
    Complex(const double &ia,const double &ib){a=ia,b=ib;}
    // 重载运算符
    inline Complex operator+(const Complex &x){return Complex(a+x.a,b+x.b);}
    inline Complex operator-(const Complex &x){return Complex(a-x.a,b-x.b);}
    inline Complex operator*(const Complex &x){return Complex(a*x.a-b*x.b,a*x.b+b*x.a);}
};

2-3 单位根

这里需要一点三角函数的知识。

由单位根的性质可知

\omega_n^k= \left\{ \begin{aligned} &\omega_n^{k+1}\times\omega_n^{-1}, &k<-1\\ &(\cos \frac{2\pi}n,-\sin \frac{2\pi}n), &k=-1\\ &1, &k=0\\ &(\cos \frac{2\pi}n,\sin \frac{2\pi}n), &k=1\\ &\omega_n^{k-1}\times\omega_n^1, &k>1 \end{aligned} \right.

由此我们可以通过递推求出 \omega_n^k

另外,为了更高的精度,我们这样定义 \pi

const double pi=acos(-1.0);

2-4 递归实现 FFT

F(\omega_n^k) &=F_l(\omega_{n/2}^k)+\omega_n^kF_r(\omega_{n/2}^k),\\ F(\omega_n^{k+n/2}) &=F_l(\omega_{n/2}^k)-\omega_n^kF_r(\omega_{n/2}^k),\\ a_k &=\frac 1 n\sum\limits_{i=0}^{n-1}b_i(\omega_n^{-k})^i. \end{aligned}

对比公式和代码有助于更好地理解。

void fft(Complex *&a,const int n,const bool &inv)
{
    if(n==1)    return;
    const int bn=n>>1;  // n/2
    Complex *la=a,*ra=a+bn;
    for(int k=0;k<n;++k)    tmp[k]=bg[k];
    // 根据次数的奇偶分成两部分
    for(int k=0;k<bn;++k)   la[k]=tmp[k<<1],ra[k]=tmp[(k<<1)|1];
    // 分治
    fft(bg,ra,inv),fft(ra,ed,inv);
    // w1 为第 1 个 n 次单位根,wk 为第 k 个 n 次单位根
    Complex w1(cos(2.0*pi/n),sin(2.0*pi/n)),wk(1.0,0.0),t;
    if(inv) w1.b=-w1.b;
    for(int k=0;k<bn;++k)
    {
        t=wk*ra[k];     // 由于复数乘法较慢,这里我用一个临时变量储存积
        tmp[k]=la[k]+t;
        tmp[k+bn]=la[k]-t;
        wk=wk*w1;
    }
    for(int k=0;k<n;++k)    a[k]=tmp[k];
}

2-5 蝴蝶变换

我们来看分治过程中各项的原下标:

0 1 2 3 4 5 6 7
0 2 4 6|1 3 5 7
0 4|2 6|1 5|3 7
0|4|2|6|1|5|3|7

下标变化如下:

0->0
1->4
2->2
3->6
4->1
5->5
6->3
7->7

可以发现是两两交换:

0<->0
1<->4
2<->2
3<->6
5<->5
7<->7

然而貌似并没有什么用……

转换成二进制试试:

000<->000
001<->100
010<->010
011<->110
101<->101
111<->111

发现玄机了吧? ????

显然就是将下标的二进制翻转了!

用递推搞定:

for(int i=0;i<N;++i)
    rev[i]=(rev[i>>1]>>1)|((i&1)?(n>>1):0);

仔细看就能搞懂了。

2-6 合并优化

对于

for(int k=0;k<bn;++k)
{
    t=wk*ra[k];
    tmp[k]=la[k]+t;
    tmp[k+bn]=la[k]-t;
    wk=wk*w1;
}
for(int k=0;k<n;++k)    a[k]=tmp[k];

这部分,由于 la=a,ra=a+bn ,所以可以替换为

for(int k=0;k<bn;++k)
{
    t=wk*ra[k];
    // 注意:下面两条语句的顺序已交换,请按照此顺序编写
    a[k+bn]=la[k]-t;
    a[k]=la[k]+t;
    wk=wk*w1;
}

这样就可以避免大量的数组拷贝。

现在 FFT 实现如下:

void fft(Complex *&a,const int n,const bool &inv)
{
    if(n==1)    return;
    const int bn=n>>1;
    Complex *la=a,*ra=a+bn;
    fft(bg,ra,inv),fft(ra,ed,inv);
    Complex w1(cos(2.0*pi/n),sin(2.0*pi/n)),wk(1.0,0.0),t;
    if(inv) w1.b=-w1.b;
    for(int k=0;k<bn;++k)
    {
        t=wk*ra[k];
        a[k+bn]=la[k]-t;
        a[k]=la[k]+t;
        wk=wk*w1;
    }
}

请在 FFT 之前交换二进制翻转的下标:

for(int i=0;i<n;++i)
    if(i<rev[i])
        swap(a[i],a[rev[i]]);

2-7 迭代实现 FFT

从下往上合并,这样可以减少许多不必要的操作,使得时间得到进一步优化:

void fft(Complex *a,const bool &inv)
{
    // 蝴蝶变换
    for(int i=0;i<N;++i)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    // 枚举 n
    for(int n=2;n<=N;n<<=1)
    {
        const int bn=n>>1;
        Complex w1(cos(2.0*pi/n),sin(2.0*pi/n));
        if(inv) w1.b=-w1.b;
        for(int l=0;l<N;l+=n)
        {
            Complex wk(1.0,0.0),t,*la=a+l,*ra=a+l+bn;
            for(int k=0;k<bn;++k)
            {
                t=wk*ra[k];
                // 这里用 la 替代 a
                la[k+bn]=la[k]-t;
                la[k]=la[k]+t;
                wk=wk*w1;
            }
        }
    }
}

三、应用

费了那么大的劲弄出来的 FFT,总不能仅仅是 DFT 后再 IDFT 吧?那不就什么也没干。让我们来应用吧!

1. 多项式乘法

模板题:P3803 【模板】多项式乘法(FFT)。

还记得之前说过的吗?

设已知两个一元 n 次多项式 F(x),G(x) 的点值表达,W(x)=F(x)\times G(x) 。很显然,多项式 W(x)x=i 时的点值为 W(i)=F(i)\times G(i)

所以,我们分别对两个多项式 DFT 一遍,然后将两者的点值的积 IDFT 一下就行了!

我的 AC 代码:

#include <cmath>
#include <cstdio>
const double pi=acos(-1.0);
int N,M,rev[0x200005];
struct Complex
{
    double a,b;
    Complex(){}
    Complex(const double &ia,const double &ib){a=ia,b=ib;}
    inline Complex operator+(const Complex &x){return Complex(a+x.a,b+x.b);}
    inline Complex operator-(const Complex &x){return Complex(a-x.a,b-x.b);}
    inline Complex operator*(const Complex &x){return Complex(a*x.a-b*x.b,a*x.b+b*x.a);}
}fa[0x200005],ga[0x200005];
void swap(Complex &x,Complex &y){Complex t=x;x=y;y=t;}
int getint()
{
    char c;
    for(;c=getchar(),c<'0' || c>'9';);
    return c&0xf;
}
void putint(const int num)
{
    if(num>=10)
        putint(num/10);
    putchar(num%10|'0');
}
void fft(Complex *a,const bool &inv)
{
    for(register int i=0;i<N;++i)   if(i<rev[i])    swap(a[i],a[rev[i]]);
    for(register int n=2,l,k;n<=N;n<<=1)
    {
        const register int bn=n>>1;
        register Complex w1(cos(2.0*pi/n),sin(2.0*pi/n));
        if(inv) w1.b=-w1.b;
        for(l=0;l<N;l+=n)
        {
            register Complex wk(1.0,0.0),t,*la=a+l,*ra=a+l+bn;
            for(k=0;k<bn;++k)
            {
                t=wk*ra[k];
                ra[k]=la[k]-t;
                la[k]=la[k]+t;
                wk=wk*w1;
            }
        }
    }
}
inline void init()
{
    register int i;
    scanf("%d%d",&N,&M);
    for(i=0;i<=N;++i)   fa[i].a=getint();
    for(i=0;i<=M;++i)   ga[i].a=getint();
    for(M+=N,N=1;N<=M;N<<=1);
    for(i=0;i<N;++i)    rev[i]=(rev[i>>1]>>1)|((i&1)?(N>>1):0);
}
inline void solve()
{
    fft(fa,false),fft(ga,false);
    for(register int i=0;i<N;++i)   fa[i]=fa[i]*ga[i];
    fft(fa,true);
    for(register int i=0;i<=M;++i)  putint((int)(fa[i].a/N+0.49)),putchar(' ');
}
int main()
{
    init();
    solve();
    return 0;
}

2. 高精度乘法

模板题:P1919 【模板】A*B Problem升级版(FFT快速傅里叶)。

高精度乘法其实就是将多项式乘法稍稍改动了而已。

设有两个自然数 a,b ,两个数字分别有 n,m 位。设存在一元 n 次多项式 F(x) 和一元 m 次多项式 G(x) ,使 F(10)=a,G(10)=b ,则 a\times b=(F\times G)(10)

The End

参考: