从零开始的 FFT

· · 算法·理论

更好的阅读体验

概述

离散傅里叶变换(Discrete Fourier Transform,缩写为 DFT),是傅里叶变换在时域和频域上都呈离散的形式,将信号的时域采样变换为其 DTFT 的频域采样。

FFT 是一种高效实现 DFT 的算法,称为快速傅立叶变换(Fast Fourier Transform,FFT)。它对傅里叶变换的理论并没有新的发现,但是对于在计算机系统或者说数字系统中应用离散傅立叶变换,可以说是进了一大步。快速数论变换 (NTT) 是快速傅里叶变换(FFT)在数论基础上的实现。

在 1965 年,Cooley 和 Tukey 发表了快速傅里叶变换算法。事实上 FFT 早在这之前就被发现过了,但是在当时现代计算机并未问世,人们没有意识到 FFT 的重要性。一些调查者认为 FFT 是由 Runge 和 König 在 1924 年发现的。但事实上高斯早在 1805 年就发明了这个算法,但一直没有发表。

作用

Q:FFT 有什么用?

A:在信息学竞赛上主要用于多项式乘法。

前置知识

多项式

先看 百度百科

定义

几个由数和字母的积组成的代数式之和为多项式。对于 f(x) = \sum^{n}_{i = 0}a_i\times x^i 是一个关于 xn 次多项式。

表示

系数表示

根据定义,若 f(x) = \sum^{n}_{i = 0}a_i\times x^i 可以用一个以 a 为坐标的 n+1 维向量表示,这种表示方法可以称为系数表示

eg:有 A(x)=5+7x+39x^2+4x^5 可表示为 \begin{bmatrix}5 \\ 7 \\39 \\0 \\4\end{bmatrix}

点值表示

发现,关于 x 多项式给以看成一个关于 x 的函数。在初中我们就学过一个 n 次函数可以用 n+1 个过此函数的点表示。

eg:有 A(x)=5+7x+39x^2+4x^5 可表示为 \{(-2.095, 0),\ (-0.089, 4.685),\ (0, 5),\ (0.22,8.487), \ (0.258, 9.406),\ (-2.103, -1.832) \}

证明:待定系数,带入 n + 1 个点,可解出系数。

运算

如下(A = \sum^{n}_{i = 0}a_i\times x^iB = \sum^{m}_{i = 0}b_i\times x^i):

\sum^{\max\{n,m\}}_{i = 0}(a_i+b_i)\times x^i & op = +\\ \sum^{\max\{n,m\}}_{i = 0}(a_i-b_i)\times x^i & op = -\\ \sum^{n}_{i = 0} \sum^{m}_{j = 1}a_j\times b_{i - j}\times x^{i + j} & op = \times\\ \Large{\frac{ \sum^{n}_{i = 0}a_i\times x^i}{\sum^{m}_{i = 0}b_i\times x^i}} & op = \div \end{matrix}\right.

多项式计算律同整数。

复数

先看 百度百科

定义1

定义 i = \sqrt{-1} ,形如 z=a+b\times i 的数为负数,集合符合为 \mathbb{C} 。其中 a 为实部,b 为虚部。

z_1=a+b\times iz_2=a-b\times i 两复数实部相同虚部相反则称 z_1z_2 共轭,z_1z_2 可分别记为 \overline{z_2}\overline{z_1}

z=a+b\times i 复数到原点的距离为复数的模长,记作 |z|=\sqrt{a^2+b^2}

辐角为负数与实数轴正方向的夹角。

表示

代数表示

如定义如 z=a+b\times i

极坐标表示

使用模长加辐角表示如 z=(a,\theta)=acos\theta +asin\theta\times i

指数表示

ae^{i\theta } = (a,\theta)

运算

如下 z_1 = a + b\times iz_2 = c + d \times i

(a + c) + (b + d) \times i & op = +\\ (a - c) + (b - d) \times i& op = -\\ (ac - bd) + (bc + ad) \times i & op = \times\\ \Large\frac{(ac + bd)+(bc - ad)\times i}{c^2 + d ^2} & op = \div \end{matrix}\right.

复数计算律同整数。

定义2

方程 x^n=1 的解是单位根,记为 \omega_n^kk\in\mathbb{Z}\cap [1, n]),得 \omega_n^{k-1} = (1, \frac{2\pi k}{n}) = e^{\frac{2\pi ki}{n}} = \cos\frac{2\pi k}{n}+isin\frac{2\pi k}{n}

eg:n=3 的单位根集合为 \{1,\frac{-1+\sqrt{3}i}{2},\frac{-1-\sqrt{3}i}{2}\}

单位根的性质

折半性质:\omega^{2k}_{2n} = \omega^{k}_{n}

证明:\omega^{2k}_{2n} = \cos\frac{2\pi \times 2k}{2n}+isin\frac{2\pi \times 2k}{2n}=\cos\frac{2\pi k}{n}+isin\frac{2\pi k}{n} = \omega^{k}_{n}

幂运算性质:(\omega^k_{2n})^2=\omega^k_{n}

证明:(\omega^k_{2n})^2 = e^{\frac{2\pi i \times 2k}{2n}} = e^{\frac{2\pi i \times k}{n}}=\omega^k_{n}

负共轭对称性:\omega^{n+k}_{2n}=-\omega^k_{2n}

证明:\omega^{n+k}_{2n} = e^{\frac{2\pi i(n+ k)}{2n}}=e^{2\pi i}+e^{\frac{2\pi k}{n}} =-\omega^k_{2n}

正文

终于学完前置知识,下面切入正题。

朴素算法的多项式乘法直接套用定义,时间复杂度为 O(n^2),这导致当 n 很大如P3803 【模板】多项式乘法(FFT) - 洛谷时就挂了。

考虑使用不同于普通的点值表示法,因为点值表示乘法时间复杂度为 O(n) ,方法为把每个对应的点相乘。如何求点值?最简单的想法是带入 n+1 个点的 x 坐标算出 y 坐标,但是时间复杂度退化成了 O(n^2),由此一个天才般的算法 FFT 就产生了。

推导

FFT

现在有一个多项式 A(x) = \sum^{n}_{i = 1}a_i\times x^i。考虑对其按 x 次数奇偶分类得到 A(x)=\sum^{\frac{n}{2}}_{i = 0}a_{2i}\times x^{2i} + \sum^{\frac{n}{2}}_{i = 0}a_{2i+1}\times x^{2i+1}

A_1(x)=\sum^{\frac{n}{2}}_{i = 0}a_{2i}\times x^{i}A_2(x)=\sum^{\frac{n}{2}}_{i = 0}a_{2i+1}\times x^{i+1},则:A(x)=A_1(x^2)+xA_2(x^2)

代入 x = \omega^k_n (k<\frac{n}{2}) 得:A(\omega^k_n) = A_1(\omega^{2k}_n) + \omega^k_n A_2(\omega^{2k}_n) = A_1(\omega^k_{n/2}) + \omega^k_n A_2(\omega^k_{n/2})

代入 x = \omega^{k+\frac{n}{2}}_n 得:A(\omega_n^{k+\frac{n}{2}}) = A_1(\omega_n^{2k+n}) + \omega_n^{k+\frac{n}{2}}(\omega_n^{2k+n})= A_1(\omega_n^{2k} \cdot \omega_n^n) - \omega_n^k A_2(\omega_n^{2k} \cdot \omega_n^n) = A_1(\omega_n^{2k}) - \omega_n^k A_2(\omega_n^{2k})

有没有发现什么?

这两个式子只有常数项相同,所以我们只要计算第一个式子就可以顺便求出第二个式子的值。

可发现一二两式范围相同、无重叠、覆盖整个求解区间,故可分治。

时间复杂度:\Theta (n\log n)

DFT

转点值表示后还要再把起转化成系数表示。

发现之前的计算就是进行了如下矩阵乘法。

\begin{pmatrix} (w_n^0)^0 & (w_n^1)^0 & \cdots & (w_n^{n-1})^0 \\ (w_n^0)^1 & (w_n^1)^1 &\cdots & (w_n^{n-1})^1 \\ \vdots & \vdots & \ddots & \vdots \\ (w_n^0)^{n-1} & (w_n^1)^{n-1} &\cdots & (w_n^{n-1})^{n-1} \end{pmatrix}\begin{pmatrix} a_0 \\ a_1 \\ \vdots \\ a_{n-1}\end{pmatrix}=\begin{pmatrix} A(w_n^0) \\ A(w_n^1) \\ \vdots \\ A(w_n^{n-1}) \end{pmatrix}

定义:D=\begin{pmatrix} (w_n^0)^0 & (w_n^1)^0 & \cdots & (w_n^{n-1})^0 \\ (w_n^0)^1 & (w_n^1)^1 &\cdots & (w_n^{n-1})^1 \\ \vdots & \vdots & \ddots & \vdots \\ (w_n^0)^{n-1} & (w_n^1)^{n-1} &\cdots & (w_n^{n-1})^{n-1} \end{pmatrix}V = \begin{pmatrix} a_0 \\ a_1 \\ \vdots \\ a_{n-1}\end{pmatrix}

对于 (D \times V)_{ij} = \sum_{k=0}^{k<n} d_{ik} \times v_{kj} = \sum_{k=0}^{k<n} w_n^{-ik} \times w_n^{kj} = \sum_{k=0}^{k<n} w_n^{k(j-i)}

i=j 时:\text{原式}=n

i\ne j 时:\text{原式}=0

\because \omega^n_n = 1 \therefore \frac{D}{n}=V^{-1}

带入原公式可得:

\begin{pmatrix} a_0 \\ a_1 \\ \vdots \\ a_{n-1} \end{pmatrix}=\frac{1}{n}\begin{pmatrix} (w_n^{-0})^0 & (w_n^{-1})^0 & \cdots & (w_n^{-(n-1)})^0 \\ (w_n^{-0})^1 & (w_n^{-1})^1 & \cdots & (w_n^{-(n-1)})^1 \\ \vdots & \vdots & \ddots & \vdots \\ (w_n^{-0})^{n-1} & (w_n^{-1})^{n-1} & \cdots & (w_n^{-(n-1)})^{n-1} \end{pmatrix}\begin{pmatrix} A(w_n^0) \\ A(w_n^1) \\ \vdots \\ A(w_n^{n-1}) \end{pmatrix}

对比可以发现,D 中的每一项都变成了倒数,故只要把单位根替换成倒数跑 FFT 在除 n 即可。

code(递归)

#include <bits/stdc++.h>

#include <complex> // 使用STL复数库替代自定义实现

using namespace std;

// 常量定义
const double PI = acos(-1); // 精确计算圆周率(比硬编码更可靠)
const int N = 1 << 21; // 最大处理长度2^21(约2百万项)
typedef complex < double > Comp; // 复数类型简写

Comp f[N], g[N]; // 存储多项式系数的复数数组
vector < int > rev; // 位逆序置换表

/**
 * 快速傅里叶变换(非递归优化版)
 * @param a 复数数组指针
 * @param n 变换长度(必须为2的幂)
 * @param op 变换方向:1=正向变换,-1=逆向变换
 */
void FFT(Comp * a, int n, int op) {
    // 第一步:位逆序置换(Cache优化关键)
    for (int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(a[i], a[rev[i]]); // 避免重复交换

    // 第二步:分层蝴蝶运算(现代CPU流水线友好)
    for (int len = 2; len <= n; len <<= 1) {
        Comp wn(cos(2 * PI / len), op * sin(2 * PI / len)); // 当前层的单位根
        for (int l = 0; l < n; l += len) { // 分块处理
            Comp w(1, 0);
            for (int k = l; k < l + len / 2; ++k) { // 蝶形运算
                Comp x = a[k], y = w * a[k + len / 2];
                a[k] = x + y; // 前半部分
                a[k + len / 2] = x - y; // 后半部分(利用共轭对称性)
                w *= wn; // 更新旋转因子
            }
        }
    }
}

int main() {
    ios::sync_with_stdio(false), cin.tie(0); // IO优化

    // 输入处理
    int n, m;
    cin >> n >> m; // 两个多项式的最高次项
    for (int i = 0; i <= n; ++i) cin >> f[i]; // 读入多项式A
    for (int i = 0; i <= m; ++i) cin >> g[i]; // 读入多项式B

    // 计算扩展长度(最近的2的幂)
    int lim = 1, l = 0;
    while (lim <= n + m) lim <<= 1, ++l; // lim=最终长度,l=二进制位数

    // 初始化位逆序表(时空权衡优化)
    rev.resize(lim);
    for (int i = 0; i < lim; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1)); // 位运算魔术

    // 正向FFT(系数->点值)
    FFT(f, lim, 1);
    FFT(g, lim, 1);

    // 点值相乘(O(n)复杂度核心)
    for (int i = 0; i < lim; ++i) f[i] *= g[i];

    // 逆向FFT(点值->系数)
    FFT(f, lim, -1);

    // 结果输出(注意精度处理)
    for (int i = 0; i <= n + m; ++i)
        cout << (int)(fabs(f[i].real()) / lim + 0.5) << " "; // 四舍五入

    return 0;
}

优化

递推优化

递归实现的 FFT 常数巨大。所以考虑改成递推版。

转移位置

若要递推自然就需要知道转移到哪里。我们来手模一下 8 项式拆分过程:

你发现了什么规律吗?

其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标。比如 x_1001,翻转是 100,也就是 4,而且在最后的位置确实是 4。我们称这个变换为位逆序置换。

证明:对于长度为 N = 2^m 的序列,位逆序置换函数 \text{rev}(i)m 位二进制数 i = (b_{m-1}b_{m-2}\dots b_1b_0)_2 映射为其位反转结果 (b_0b_1\dots b_{m-2}b_{m-1})_2

m = 1 时:N = 2^1 = 2i \in \{0,1\}\text{rev}(0) = 0 = (0)_2\text{rev}(1) = 1 = (1)_2,命题成立。

假设当 m = k 时命题成立:对任意 i \in \{0, \dots, 2^k-1\},若 i = (b_{k-1}\dots b_0)_2,则 \text{rev}(i) = (b_0\dots b_{k-1})_2

m = k+1 时:设 i = (b_kb_{k-1}\dots b_0)_2,按最低位 b_0 划分为 i = 2jb_0=0)或 i = 2j+1b_0=1),其中 j = (b_k\dots b_1)_2 \in \{0, \dots, 2^k-1\}。由归纳假设,\text{rev}_k(j) = (b_1\dots b_k)_2\text{rev}_kk 位逆序函数)。左子树(i=2j)逆序结果为 (0b_1\dots b_k)_2,右子树(i=2j+1)逆序结果为 (1b_1\dots b_k)_2。合并后,\text{rev}(i) = (b_0b_1\dots b_k)_2,即 i 的位反转结果。

故对所有 m \geq 1,位逆序置换 \text{rev}(i)im 位二进制逆序。

由归纳步骤可推导出 \text {rev}(i) 的递推关系:

\begin{cases} \text{rev}(i/2) \ll 1 & \text{若 } i \text{ 为偶数}, \\ \text{rev}((i-1)/2) \ll 1 + 2^k & \text{若 } i \text{ 为奇数} \end{cases}$。 等价于位运算形式: $\text{rev}(i) = (\text{rev}(i \gg 1) \ll 1) \lor ((i \& 1) \ll k)

由此得到对应下标,之后向上同递归一样合并就可以了。

code

#include <bits/stdc++.h> 
using namespace std; 

typedef complex<double> Comp; 
const double PI = acos(-1); 
const int N = 1 << 21;  // 最大处理长度:2^21(约200万项) 

Comp f[N], g[N];        // 存储多项式系数的复数数组 
vector<int> rev;        // 位逆序置换表 

/** 
 * 快速傅里叶变换(非递归优化版) 
 * @param a     复数数组指针(输入多项式系数/输出点值) 
 * @param n     变换长度(必须为2的幂) 
 * @param op    变换方向:1=正向DFT,-1=逆向IDFT 
 */ 
void FFT(Comp* a, int n, int op) { 
    // 第一步:位逆序置换(Cache优化:仅交换i < rev[i]的元素) 
    for (int i = 0; i < n; ++i) { 
        if (i < rev[i]) { 
            swap(a[i], a[rev[i]]); 
        } 
    } 

    // 第二步:分层蝴蝶运算(自底向上合并子问题) 
    for (int len = 2; len <= n; len <<= 1) {  // len:当前合并的子序列长度 
        Comp wn(cos(2 * PI / len), op * sin(2 * PI / len));  // 单位根 w_len^1 
        for (int l = 0; l < n; l += len) {  // l:当前块的起始位置 
            Comp w(1, 0);  // 旋转因子初始化为 w_len^0 = 1 
            for (int k = l; k < l + len/2; ++k) {  // 对块内前半部分元素蝴蝶操作 
                Comp x = a[k]; 
                Comp y = w * a[k + len/2]; 
                a[k] = x + y;          // 前半部分结果 
                a[k + len/2] = x - y;  // 后半部分结果(利用对称性) 
                w *= wn;               // 更新旋转因子:w = w_len^(k+1) 
            } 
        } 
    } 
} 

int main() { 
    ios::sync_with_stdio(false); 
    cin.tie(0);   // IO优化:关闭同步流,加速输入输出 

    // 输入多项式A和B的系数 
    int n, m; 
    cin >> n >> m; 
    for (int i = 0; i <= n; ++i) cin >> f[i];  // 读入多项式A(次数n) 
    for (int i = 0; i <= m; ++i) cin >> g[i];  // 读入多项式B(次数m) 

    // 计算最小扩展长度(2的幂,需覆盖A*B的最高次n+m) 
    int lim = 1, bit = 0; 
    while (lim <= n + m) { 
        lim <<= 1;  // lim = 2^bit,其中bit为二进制位数 
        bit++; 
    } 

    // 初始化位逆序表(递推公式:rev[i] = (rev[i>>1]>>1) | ((i&1) << (bit-1))) 
    rev.resize(lim);  
    for (int i = 0; i < lim; ++i) { 
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); 
    } 

    // 正向FFT:系数表示 → 点值表示 
    FFT(f, lim, 1); 
    FFT(g, lim, 1); 

    // 点值相乘:(A*B)的点值 = A的点值 * B的点值(O(n)复杂度) 
    for (int i = 0; i < lim; ++i) { 
        f[i] *= g[i]; 
    } 

    // 逆向FFT:点值表示 → 系数表示(需除以lim归一化) 
    FFT(f, lim, -1); 

    // 输出结果:四舍五入取实部(消除浮点误差) 
    for (int i = 0; i <= n + m; ++i) { 
        cout << (int)(fabs(f[i].real()) / lim + 0.5) << " "; 
    } 

    return 0; 
} 

三次变两次优化

这里还有一种把总共执行 3 次的 FFT 改成 2 次。 我们可以把第一个多项式放在实部第二个放在虚部,求出平方,把虚部取出除 2 为答案。

设第一个多项式为 A,第二个为 B。 按操作得 (A+Bi)^2=(A^2 - B^2)+2ABi

\because 2ABi \div 2 = AB \therefore \text{得证}

code

#include <bits/stdc++.h>
using namespace std;

typedef complex<double> Comp;
const double PI = acos(-1);
const int N = 1 << 22;  // 最大支持 2^22 项(约400万)

Comp a[N];  // 合并存储多项式:实部=A(x),虚部=B(x)
vector<int> rev;  // 位逆序置换表

/**
* 快速傅里叶变换(非递归版)
* @param n 序列长度(2的幂)
* @param op 1=正向DFT,-1=逆向IDFT
*/
void fft(int n, int op) {
    // 位逆序置换(预交换)
    for (int i = 0; i < n; ++i) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    }

    // 分层蝴蝶运算
    for (int len = 2; len <= n; len <<= 1) {
        Comp wn(cos(2 * PI / len), op * sin(2 * PI / len));
        for (int l = 0; l < n; l += len) {
            Comp w(1, 0);
            for (int k = l; k < l + len/2; ++k) {
                Comp x = a[k], y = w * a[k + len/2];
                a[k] = x + y;
                a[k + len/2] = x - y;
                w *= wn;
            }
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    int n, m;
    cin >> n >> m;

    // 读入多项式A(实部)和B(虚部)
    double val;  // 临时变量存储输入值
    for (int i = 0; i <= n; ++i) {
        cin >> val;
        a[i] = Comp(val, a[i].imag());  // 修改实部,保留原有虚部(初始为0)
    }
    for (int i = 0; i <= m; ++i) {
        cin >> val;
        a[i] = Comp(a[i].real(), val);  // 保留原有实部,修改虚部
    }

    // 计算最小长度(覆盖A*B的最高次n+m)
    int lim = 1, bit = 0;
    while (lim <= n + m) {
        lim <<= 1;
        bit++;
    }

    // 初始化位逆序表
    rev.resize(lim);
    for (int i = 0; i < lim; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }

    // 正向FFT:将A和B的系数同时转换为点值
    fft(lim, 1);

    // 点值相乘:(A + Bi)^2 = (A2 - B2) + 2ABi,虚部/2即为A*B的点值
    for (int i = 0; i < lim; ++i) {
        a[i] = a[i] * a[i];  // 复数平方
    }

    // 逆向FFT:将结果转换回系数表示(虚部/2为答案)
    fft(lim, -1);

    // 输出A*B的系数:虚部/2/lim(四舍五入)
    for (int i = 0; i <= n + m; ++i) {
        double res = a[i].imag() / 2.0 / lim;  // 提取虚部并归一化
        cout << (int)(fabs(res) + 0.5) << " ";
    }

    return 0;
}