【多项式】快速傅里叶变换(FFT)【算法竞赛入门笔记】

· · 算法·理论

快速傅里叶变换

导入

这是一个改变了人类世界的算法,其历史甚至可以追溯到 C. F. Gauss(就是最有名的那个高斯)(1805年)。今天,我们就从多项式的视角来看这个算法。

我们想要计算两个多项式的乘积。举个例子:

(2x^2+x+3)(x^2+3x+5)

我们肯定都会计算为

&\phantom{{}={}}(2x^2+x+3)(x^2+3x+5)\\ &=2x^2(x^2+3x+5)+x(x^2+3x+5)+3(x^2+3x+5)\\ &=(2x^4+6x^3+10x^2)+(x^3+3x^2+5x)+(3x^2+9x+15)\\ &=2x^4+7x^3+16x^2+14x+15 \end{aligned}

这类算法的时间复杂度在 \Theta(d^2),其中 d 为两个多项式的次数(假设同阶)。

有没有更快的算法?

多项式的表示

如何表示一个多项式 A(x) 呢?

一个显然的思路是,把次数从 0 开始从低到高的每一项在一个向量内存储。比如 2x^2+x+3 可以表示为 (3,1,2)2x^4+7x^3+16x^2+14x+15 可以表示为 (15,14,16,7,2)

另一种方式是通过一些数值 x_i 和他们对应的值 A(x_i) 来表示。那么需要多少对数值?初中的时候,我们学习一次、二次多项式函数的时候常常分别把 2 个、3个点代进去来求解析式。实际上,对于一个 n-1 次的多项式,需要 n 对数值来确定。

略证:

这个问题就相当于解矩阵方程

\begin{bmatrix} 1&x_0&x_0^2&\cdots &x_0^{n-1}\\ 1&x_1&x_1^2&\cdots &x_1^{n-1}\\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1&x_d&x_d^2&\cdots &x_{n-1}^{n-1}\\ \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ \vdots\\ a_{n-1}\\ \end{bmatrix}= \begin{bmatrix} y_0\\ y_1\\ \vdots\\ y_{n-1}\\ \end{bmatrix}

只要证最左侧这个矩阵 V 可逆。这样的话,就有

最左侧这个矩阵叫做 Vandermonde 矩阵。它的行列式为 $\prod\limits_{0\le j<k\le n-1}(x_k-x_j)

显然 x_i 互不相同。因此这个行列式不为 0,故矩阵可逆。\blacksquare

因此一个次数小于 n 的多项式 A(x) 可以表示为 n 个数值对构成的集合

\{(x_0,y_0),(x1,y1),\cdots,(x_{n-1},y_{n-1})\}

其中所有 x_i 互不相同,且 y_i=A(x_i)。这种表示方法称为点值表示

回到多项式乘法

看起来点值表示还有点搞头。

我们是否可以把两个多项式表达为点值表示,然后分别把 y 相乘,得到新的点值表示?

这个问题分为三步。

  1. 把系数表示换为点值表示。
  2. 把点值表示的 y 分别乘起来。
  3. 把点值表示换为系数表示。

第二步没什么问题,重点就是如何快速进行第一步和第三步。先考虑第一步。

一个平凡的想法是随便取几个值来计算。但是这样的话时间复杂度又回到了 \Theta(n^2)。因此这个值不能随便取。

接下来,一些非常伟大的数学家想到了一个极为奇妙的方法。

假设有一个多项式有 n 项,其中 n2 的幂次。它的系数向量形如 (a0,a1,\cdots,a_{n-1})。我们把它拆成两部分:

A^{[1]}(x)=a_1+a_3x+a_5x^2+\cdots+a_{n-1}x^{n/2-1}

显然 A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)

于是为了求解 A(x) 的点值表示,我们只要求出 A^{[0]}(x^2)A^{[1]}(x^2) 的点值表示。很显然我们只需要使用递归算法来求解!

——吗?

我们假设要求 A(x)x_0, x_1, \cdots,x_{n-1} 处的值。那么问题转化为了:

怎么办?有没有一堆数的集合,使得其中每个数的平方的集合的规模是它的规模的一半,而且这个性质仍然保留? 答案呼之欲出:使用复数。如何构造这样的复数呢? 不难想到,我们定义 $\omega_n\coloneqq \mathrm{e}^{2\pi \mathrm{i}/n}=\cos(2\pi/n)+\mathrm{i}\sin(2\pi/n)$。考虑它的幂次: 例如, $\omega_8^0=1 \\ \omega_8^1=\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}\mathrm{i} \\ \omega_8^2=\mathrm{i} \\ \omega_8^3=-\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}\mathrm{i} \\ \omega_8^4=-1 \\ \omega_8^5=-\frac{\sqrt{2}}{2}-\frac{\sqrt{2}}{2}\mathrm{i} \\ \omega_8^6=-\mathrm{i} \\ \omega_8^7=\frac{\sqrt{2}}{2}-\frac{\sqrt{2}}{2}\mathrm{i} \\

注意它们的平方:

(\omega_8^1)^2=\mathrm{i} \\ (\omega_8^2)^2=-1 \\ (\omega_8^3)^2=-\mathrm{i} \\ (\omega_8^4)^2=1 \\ (\omega_8^5)^2=\mathrm{i} \\ (\omega_8^6)^2=-1 \\ (\omega_8^7)^2=-\mathrm{i} \\

\omega_4^1=\mathrm{i} \\ \omega_4^2=-1 \\ \omega_4^3=-\mathrm{i}

我们猜想:n 次单位复数根(n>0 为偶数)的平方的集合,就是 n/2 次单位复数根的平方的集合。

略证:

因为

(\omega_n^k)^2=(\mathrm{e}^{2\pi \mathrm{i}/n})^{2k}=(\mathrm{e}^\frac{2\pi \mathrm{i}}{n/2})^k=\omega_{n/2}^k

(\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}\omega_n^n=(\omega_n^k)^2

因此每个 n/2 次单位复数根恰好出现两次。\blacksquare

我们发现,只要 n2 的幂次,n 次单位复数根就有刚才我们梦寐以求的那个性质!

离散傅里叶变换(Discrete Fourier Transform,DFT)

我们的目标很明确了:计算次数小于 n 的多项式 A(x)\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1} 处的值。

A(x) 的系数向量为 (a_0,a_1,\cdots,a_{n-1})。对于 k=0,1,\cdots,n-1,定义

y_k\coloneqq A(\omega_n^k)=\sum_{j=0}^{n-1}a_j\omega_n^{kj}

并把向量 y=(y_0,y_1,\cdots,y_{n-1}) 称为向量 a离散傅里叶变换

快速傅里叶变换(Fast Fourier Transform,FFT)

我们很容易得到一个递归算法来计算一个向量 a=(a_0,a_1,\cdots,a_{n-1}) 的 DFT,其中 n2 的幂次:

  1. 求出 A^{[0]}(x)A^{[1]}(x)(\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2(这里只有 n/2 个不同的值!)处的值。
  2. 通过 A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2) 合并结果。
  3. 边界:当 n=1,返回 a。(你自己核实)

用 C++ 语言书写一下:

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>
#include <complex>
const long double pi = 3.14159265358979323846;
std::vector <std::complex<long double>> FFT(std::vector <std::complex<long double>> a, int n) {
    if(n == 1) return a;
    std::complex <long double> Wn = {std::cos(2 * pi / n), std::sin(2 * pi / n)};
    std::complex <long double> w = 1.0;
    std::vector <std::complex <long double>> a0(n >> 1), a1(n >> 1);
    for(int i = 0; i < (n >> 1); i++)
        a0[i] = a[i << 1], a1[i] = a[(i << 1) | 1];
    std::vector <std::complex<long double>> y0 = FFT(a0, n >> 1);
    std::vector <std::complex<long double>> y1 = FFT(a1, n >> 1);
    std::vector <std::complex<long double>> y(n);
    for(int i = 0; i < (n >> 1); i++, w = w * Wn) { //这一段就是合并,以下会细讲
        y[i] = y0[i] + y1[i] * w; 
        y[i + (n >> 1)] = y0[i] - y1[i] * w; 
    }
    return y;
}
int main() {
    int n;
    std::cin >> n;
    int l = 1;
    while(l < n) l <<= 1;
    std::vector <std::complex <long double>> a(l);
    for(int i = 0; i < n; i++) {
        int t;
        std::cin >> t;
        a[i] = (long double)t;
    }
    std::vector <std::complex<long double>> fa = FFT(a, l);
    for(int i = 0; i < l; i++) printf("(%Lf,%Lf)", fa[i].real(), fa[i].imag());
    return 0;
}

列出运行时间的递推式子:T(n)=2T(n/2)+\Theta(n)。根据主定理,T(n)=\Theta(n\lg n)

关于合并:

以下把我们计算出的 A^{[0]}(x) 的系数向量 DFT 称为 y^{[0]}A^{[1]}(x) 的系数向量 DFT 称为 y^{[1]}

对于 y_0,y_1,\cdots,y_{n/2-1},我们计算 y_k^{[0]}+\omega_n^ky_k^{[1]},而

y_k&=y_k^{[0]}+\omega_n^ky_k^{[1]}\\ &=A^{[0]}(\omega_n^{2k})+\omega_n^kA^{[1]}(\omega_n^{2k})\\ &=A(\omega_n^k) \end{aligned}

对于 y_{n/2},y_{n/2+1},\cdots,y_{n-1},有

y_{k+(n/2)}&=y_k^{[0]}-\omega_n^ky_k^{[1]}\\ &=y_k^{[0]}+\omega_n^{k+(n/2)}y_k^{[1]}\\ &=A^{[0]}(\omega_n^{2k})+\omega_n^{k+(n/2)}A^{[1]}(\omega_n^{2k})\\ &=A^{[0]}(\omega_n^{2k+n})+\omega_n^{k+(n/2)}A^{[1]}(\omega_n^{2k+n})\\ &=A(\omega_n^{k+(n/2)}) \end{aligned}

我多项式乘积呢?

显然 A(x)B(x) 的 DFT 每一项的乘积组成的新向量就是 A(x)\cdot B(x) 的 DFT。(当然 n 要相等。)

那么下一个问题,就是通过 DFT 求原系数向量。就是说,要求逆 DFT(IDFT)。

逆 DFT(IDFT)

我们把 DFT 写成一个矩阵方程。

\begin{bmatrix} y_0\\ y_1\\ y_2\\ \vdots\\ y_{n-1}\\ \end{bmatrix}= \begin{bmatrix} 1&1&1&1&\cdots &1\\ 1&\omega_n&\omega_n^2&\omega_n^3&\cdots &\omega_n^{n-1}\\ 1&\omega_n^2&\omega_n^4&\omega_n^6&\cdots&\omega_n^{2(n-1)}\\ \vdots&\vdots & \vdots & \vdots & \ddots & \vdots \\ 1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\omega_n^{3(n-1)}&\cdots &\omega_n^{(n-1)(n-1)}\\ \end{bmatrix} \begin{bmatrix} a_0\\ a_1\\ a_2\\ \vdots\\ a_{n-1}\\ \end{bmatrix}

我们发现等号右面的左面(?)那个大矩阵又是一个 Vandermonde 矩阵。我们发现要求 a 只需要求 y 乘以它的逆矩阵。

你自己捯饬一下就会发现它的逆矩阵长这样:

1&1&1&1&\cdots &1\\ 1&\omega_n^{-1}&\omega_n^{-2}&\omega_n^{-3}&\cdots &\omega_n^{-(n-1)}\\ 1&\omega_n^{-2}&\omega_n^{-4}&\omega_n^{-6}&\cdots&\omega_n^{-2(n-1)}\\ \vdots&\vdots & \vdots & \vdots & \ddots & \vdots \\ 1&\omega_n^{-(n-1)}&\omega_n^{-2(n-1)}&\omega_n^{-3(n-1)}&\cdots &\omega_n^{-(n-1)(n-1)}\\ \end{bmatrix}

(没捯饬出来的你自己核实)

于是可以看出,a_k=\frac{1}{n}\sum_{j=0}^{n-1}y_j\omega_n^{-kj}。看起来和 DFT 的式子很像。

没错,要想计算逆 DFT,只需要把 FFT 里头的 \omega_n 改成 \omega_n^{-1}=\cos(2\pi/n)-i\sin{(2\pi/n)}(你自己核实),然后最后求出来以后整个除以 n 即可。

实际上书写的时候,不用写那么多函数,只需要给予一个参数 o 来表明要求 DFT 还是逆 DFT。

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>
const int N = 1e6 + 7;
const long double pi = 3.14159265358979323846;
class complex {
    public:
        long double a;
        long double b;
        friend complex operator + (complex x, complex y) {
            return {x.a + y.a, x.b + y.b};
        }
        friend complex operator - (complex x, complex y) {
            return {x.a - y.a, x.b - y.b};
        }
        friend complex operator * (complex x, complex y) {
            return {x.a * y.a - x.b * y.b, x.a * y.b + x.b * y.a};
        }
};
std::vector <complex> FFT(std::vector <complex> a, int n, int o) {
    if(n == 1) return a;
    complex Wn = {std::cos(2 * pi / n), o * std::sin(2 * pi / n)};
    complex w = {1.0, 0};
    std::vector <complex> a0(n >> 1), a1(n >> 1);
    for(int i = 0; i < (n >> 1); i++)
        a0[i] = a[i << 1], a1[i] = a[(i << 1) | 1];
    std::vector <complex> y0 = FFT(a0, n >> 1, o);
    std::vector <complex> y1 = FFT(a1, n >> 1, o);
    std::vector <complex> y(n);
    for(int i = 0; i < (n >> 1); i++, w = w * Wn) {
        y[i] = y0[i] + y1[i] * w;
        y[i + (n >> 1)] = y0[i] - y1[i] * w;
    }
    return y;
}
int main() {
    int n, m; //注意这里变成了输入次数 
    std::cin >> n >> m;
    int l = 1;
    while(l <= n + m) l <<= 1;
    std::vector <complex> a(l), b(l);
    for(int i = 0; i <= n; i++) {
        int t = read();
        a[i] = {(long double)t, 0};
    }
    for(int i = 0; i <= m; i++) {
        int t = read();
        b[i] = {(long double)t, 0};
    }
    std::vector <complex> fa = FFT(a, l, 1);
    std::vector <complex> fb = FFT(b, l, 1);
    std::vector <complex> fc(l);
    for(int i = 0; i < l; i++) fc[i] = fa[i] * fb[i];
    std::vector <complex> c = FFT(fc, l, -1);
    for(int i = 0; i <= n + m; i++) printf("%d ", (int)(c[i].a / l + 0.5));
    return 0;
}

还能更快吗?

要是把这个代码交到 P3803,最后几个点一定会 TLE。究其原因,是常数过大导致的(你自己核实)。

以下介绍一种常数较小的方法。

迭代实现

首先上述代码中计算了两次 \omega_n^ky_k^{[1]}。我们可以只计算一次:

for(int i = 0; i < (n >> 1); i++) {
    complex t = y1[i] * w;
    y[i] = y0[i] + t;
    y[i + (n >> 1)] = y0[i] - t;
    w = w * Wn;
}

这个循环内的操作叫做蝴蝶操作

接下来,我们刚才实际上用一个树形结构来“分解”输入的系数向量。

举个例子:

(a_0,a_2,a_4,a_6)(a_1,a_3,a_5,a_7)\\ (a_0,a_4)(a_2,a_6)(a_1,a_5)(a_3,a_7)\\ (a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7)

稀碎。

为了不去递归,我们可以按上面最后一行中的顺序,在数组中存储所有的 a_i,接下来自底向上合并 DFT,直到产生最终结果。合并的操作和递归实现的差不多,但是使用蝴蝶操作。

问题来了,最底下那是什么顺序?

我们会发现一个不太平凡的结论:第 k 位所存储的 a 的下标,就是 k\log_2n 位二进制表示后翻转过来得到的数值。

举个例子:

翻转之后就是 $000,100,010,110,001,101,011,111$, 也就是 $0,4,2,6,1,5,3,7$。 理解和证明留给读者完成,提示:像 Huffman 编码那样给树形结构的边附上 $0,1$ 边权。 计算可以用递推的方式来计算: ```cpp for(int i = 0; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); // ``` 其中 $k=\log_2n$。 这里给出迭代实现 FFT 的 C++ 代码: ```cpp std::vector <complex> copy(std::vector <complex> a) { std::vector <complex> aa(n); for(int i = 0; i < n; i++) aa[rev[i]] = a[i]; return aa; } std::vector <complex> FFT(std::vector <complex> a) { std::vector <complex> aa = copy(a); for(int s = 1; s < n; s <<= 1) { complex Wn = {cos(pi / s), sin(pi / s)}; for(int i = 0; i < n; i += (s << 1)) { complex w = {1, 0}; for(int j = 0; j < s; j++) { //蝴蝶操作 complex t = w * aa[i + j + s]; complex u = aa[i + j]; aa[i + j] = u + t; aa[i + j + s] = u - t; w = w * Wn; } } } return a; } ``` 可以自己试着写出逆 FFT 的代码或伪代码。 以下是一个更简洁的多项式乘法代码: ```cpp #include <iostream> #include <cstdio> #include <cmath> #include <vector> const int N = 1 << 21; const long double pi = 3.14159265358979323846264338; int n, m, l = 1; class complex { public: long double a, b; friend complex operator + (complex x, complex y) { return {x.a + y.a, x.b + y.b}; } friend complex operator - (complex x, complex y) { return {x.a - y.a, x.b - y.b}; } friend complex operator * (complex x, complex y) { return {x.a * y.a - x.b * y.b, x.a * y.b + y.a * x.b}; } }; int rev[N]; std::vector <complex> FFT(std::vector <complex> &a, int o) { for(int i = 0; i < l; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]); for(int s = 1; s < l; s <<= 1) { complex Wn = {cos(pi / s), o * sin(pi / s)}; for(int i = 0; i < l; i += (s << 1)) { complex w = {1, 0}; for(int j = 0; j < s; j++) { //蝴蝶操作 complex t = w * a[i + j + s]; complex u = a[i + j]; a[i + j] = u + t; a[i + j + s] = u - t; w = w * Wn; } } } return a; } int main() { std::cin >> n >> m; int k = 0; while(l <= n + m) l <<= 1, k++; std::vector <complex> a(l), b(l); for(int i = 0; i <= n; i++) std::cin >> a[i].a; for(int i = 0; i <= m; i++) std::cin >> b[i].a; for(int i = 0; i < l; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); FFT(a, 1), FFT(b, 1); for(int i = 0; i < l; i++) a[i] = a[i] * b[i]; FFT(a, -1); for(int i = 0; i <= n + m; i++) printf("%d ", (int)(a[i].a / l + 0.5)); return 0; } ```