FFT与NTT

· · 个人记录

本文为基础部分。

多项式进阶:多项式的高级运算

相似算法:快速沃尔什变换(FWT)

FFT与NTT用来处理多项式乘法。

快速傅里叶变换(FFT)

小学生都能看懂的FFT!!!

实质是加速“将单位根代入多项式得到点值表示”的非迭代分治做法。

DFT一遍以后数组第 i 位表示的是将 i 次单位根带入多项式的点值。

IDFT相当于将 -i 次单位根带入多项式的点值。根据单位根反演的相关知识,可以知道这样能够还原第 i 项的系数。

据此可以做:P4235 Hash?

还是详细来一遍吧:

DFT

设多项式:

F(x) = \sum_{i=0}^{n-1}a_ix^i

我们要快速求出 F(x)x = \omega_n^0, \omega_n^1,...,\omega_n^{n-1} 处的点值 集合 F'(x)

有:

F'(i) = F(w_n^i) = a_0 + a_1w_n^i + a_2w_n^{2i}...a_{n-1}w_n^{(n-1)i} =(a_0 + a_2w_n^{2i}...) + w_n^i(a_1 + a_3w_n^{2i}...) = F^0(w_n^{2i}) + w_n^iF^1(w_n^{2i}) =F^0(w_{n/2}^i) + w_n^iF^1(w_{n/2}^i)

然后可以扔到两个子区间做子问题了,得到左边的 F^0(w_{n/2}^i) 以及右边的 w_n^iF^1(w_{n/2}^i) 以后就可以算出当前的 F(w_n^i) 了。

IDFT

我们断言 F'(w_n^{-k}) = na_k

证明:

F'(x)=\sum_{i=0}^{n-1}a'_ix^i =\sum_{i=0}^{n-1}(\sum_{j=0}^{n-1}a_jw_n^{ij})x^i =\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}w_n^{ij}x^i

带入 w_n^{-k}

F(w_n^{-k})=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}w_n^{ij}w_n^{-ik} =\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}w_n^{i(j-k)}

由单位根反演知:

\sum_{i=0}^{n-1}w_n^{ix}=n[n|x]

因此:

F(w_n^{-k})=n\sum_{j=0}^{n-1}a_j[n|(j-k)] =n\sum_{t=0}^{\infty}a_{tn+k}

因此,当取值充足的前提下,F(w_n^{-k})=na_k;当实际项数超过 nlimi)的时候,F(w_n^{-k})=n(a_k+a_{k+t}+...),即循环卷积

循环卷积题

2020.12.23 Update:

如果IDFT时我们硬是吧 w_n^k 代入的话会发生什么?

\begin{aligned} F'(w_n^k) &= \sum_{i=0}^{n-1} \sum_{j=0}^{n-1}a_jw_n^{ij}w_n^{ik}m\\ &= \sum_{j=0}^{n-1} a_j \sum_{i=0}^{n-1} w_n^{i(k+j)}\\ &= \sum_{j=0}^{n-1} a_jn[n|(k+j)] \end{aligned}

在不发生循环卷积的情况下,代入 w_n^0 会得到 na_0,代入 w_n^i 会得到 na_{n-i}。于是除 n 后将 1...n-1 翻转一下即可。

```cpp struct Complex { double x, y; Complex(double xx = 0, double yy = 0) {x = xx, y = yy;} Complex operator + (const Complex &i) const { return Complex(x + i.x, y + i.y); } Complex operator - (const Complex &i) const { return Complex(x - i.x, y - i.y); } Complex operator * (const Complex &a) const { return Complex(x * a.x - y * a.y, x * a.y + y * a.x); } }A[N], B[N]; int n, m, limi = 1, l; int r[N]; const double Pi = 3.14159265358979323846264; void fft(Complex *a, int type) { for (register int i = 0; i < limi; ++i) if (i < r[i]) swap(a[i], a[r[i]]); for (register int j = 1; j < limi; j <<= 1) {//长度 Complex T(cos(Pi/j), type * sin(Pi / j)); for (register int k = 0; k < limi; k += (j << 1)) {//第几块 Complex t(1, 0); for (register int p = 0; p < j; ++p, t = t * T) {//该块的第几个 Complex nx = a[k + p], ny = t * a[k + j + p]; a[k + p] = nx + ny; a[k + j + p] = nx - ny; } } } } int main() { read(n); read(m); int aa; for (register int i = 0; i <= n; ++i) read(aa), A[i].x = aa; for (register int i = 0; i <= m; ++i) read(aa), B[i].x = aa; while (limi<=n + m) limi <<= 1, l++; for (register int i = 0; i < limi; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); fft(A, 1); fft(B, 1); for (register int i = 0; i <= limi; ++i) A[i] = A[i] * B[i]; fft(A, -1); for(register int i = 0; i < limi; ++i) A[i].x /= limi;//此处亦可 reverse(A + 1, A + limi) for (register int i = 0; i <= n + m; ++i) printf("%d ", (int)(A[i].x + 0.5)); return 0; } ``` ## 快速数论变换(NTT) [快速数论变换(NTT)小结](https://www.cnblogs.com/zwfymqz/p/8980809.html) [NTT(快速数论变换)用到的各种素数及原根](https://blog.csdn.net/hnust_xx/article/details/76572828) | 素数 | 原根 | | :----------: | :----------: | | 998244353 | 3 | | 3221225473(long long) | 5 | | 395 824 185 999 37 (3e13) | 5 | 能够用原根代替单位根是因为原根具有类似单位根的性质: 1. 也有“循环”的性质 2. 不会 3. 不知 4. ... 记得取模!! $2020.7.28$ $Update:$更新了代码 $Code:
const int P = 998244353;
const int G = 3;
const int Gi = (P + 1) / G;
inline void ntt(ll *a, int type) {
    for (register int i = 1; i < limi; ++i)
        if (i < r[i])   swap(a[i], a[r[i]]);
    for (register int i = 1; i < limi; i <<= 1) {//i < limi
        ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));//Attention!!
        for (register int j = 0; j < limi; j += (i << 1)) {
            ll t = 1;
            for (register int k = 0; k < i; ++k, t = t * T % P) {//Attention!! : % P
                ll nx = a[j + k], ny = a[j + k + i] * t % P;
                a[j + k] = (nx + ny) % P;
                a[j + k + i] = (nx - ny + P) % P;
            }
        }
    }
    if (type == -1) {
        ll inv = quickpow(limi, P - 2);
        for (register int i = 0; i < limi; ++i)
            a[i] = a[i] * inv % P;
    }
}
inline void mul(ll *a, ll *b, int n, int m) {//传入 a, b,导出到 a
    while (limi <= (n + m)) limi <<= 1, ++L;
    for (register int i = 1; i < limi; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    ntt(a, 1), ntt(b, 1);
    for (register int i = 0; i < limi; ++i) a[i] = a[i] * b[i] % P;
    ntt(a, -1);
}

【模板】A*B Problem升级版(FFT快速傅里叶)

通过模拟乘法竖式,我们发现,高精乘其实就是在进行多项式乘法。这样的话我们可以用FFT或NTT来把它优化到nlogn。

Code:
#define P 998244353
#define G 3
#define Gi 332748118
char as[N], bs[N];
int n, m;
ll A[N], B[N], ans[N];
ll limi = 1, l, inv;
int r[N];
inline ll quickpow(ll x, ll k)...
inline void ntt(ll *a, int type) {
    for (register int i = 0; i <= limi; ++i) 
        if (i < r[i])   swap(a[i], a[r[i]]);
    for (register int i = 1; i < limi; i <<= 1) {
        ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));
        for (register int j = 0; j < limi; j += (i << 1)) {
            ll t = 1;
            for (register int p = 0; p < i; ++p, t = t * T % P) {
                ll nx = a[j + p], ny = t * a[j + p + i] % P;
                a[j + p] = (nx + ny) % P;
                a[j + p + i] = (nx - ny + P) % P;
            }
        }
    }
}
int main() {
    scanf("%s%s", as, bs);
    n = strlen(as) - 1;
    m = strlen(bs) - 1;
    ll ct = 0;
    for (register int i = n; i >= 0; --i) A[ct++] = as[i] - '0';
    ct = 0;
    for (register int i = m; i >= 0; --i)   B[ct++] = bs[i] - '0';
    while (limi <= n + m)   limi <<= 1, l++;
    for (register int i = 1; i <= limi; ++i) 
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    ntt(A, 1); ntt(B, 1);
    for (register int i = 0; i <= limi; ++i)    A[i] = A[i] * B[i] % P;
    ntt(A, -1);
    inv = quickpow(limi, P - 2);
    for (register int i = 0; i <= limi; ++i)
        ans[i] = A[i] * inv % P;
    limi += 5;
    for (register int i = 0; i <= limi; ++i)
        if (ans[i] >= 10) {
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10;
        }
    ll len = 1;
    for (register int i = limi; i >= 0; --i)
        if (ans[i]) break;
        else    len = i - 1;
    for (register int i = len; i >= 0; --i) {
        printf("%lld", ans[i]);
    }
    return 0;
}

卡常技巧

预处理单位根:

理论上是有空间 O(n) 的做法的,大概是二倍(乘四倍)空间,但是下面的写法更好写。

ll yuangen[18][N];
inline void ntt(ll *a, int limi, int type) {
    for (int i = 0; i < limi; ++i)
        if (i < r[i])   swap(a[i], a[r[i]]);
    for (int i = 1, ji = 0; i < limi; i <<= 1, ++ji) {
        ll* G = yuangen[ji];
        for (int j = 0; j < limi; j += (i << 1)) {
            for (int k = 0; k < i; ++k) {
                ll nx = a[j + k], ny = a[i + j + k] * G[k] % P;
                a[j + k] = (nx + ny) % P;
                a[j + k + i] = (nx - ny + P) % P;
            }
        }
    }
    if (type == -1) {
        ll inv = quickpow(limi, P - 2);
        for (int i = 0; i < limi; ++i)  a[i] = a[i] * inv % P;
        reverse(a + 1, a + limi);
    }
}
...
for (int ji = 0, i = 1; ji < 18; i <<= 1, ++ji) {
    ll* G = yuangen[ji];
    G[0] = 1;
    G[1] = quickpow(3, (P + 1) / (i << 1));
    for (int j = 2; j < i; ++j) G[j] = G[j - 1] * G[1] % P;
}

例题

通过数学推导,我们发现,要解决其中的旋转求最大的aibi的和的问题时,我们可以把它转化成求卷积(多项式乘法)后的后n项的最值问题,这里用NTT优化。但其实这道题主要还是难在数学推导的想法以及如何想到卷积。

Code:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <string>
#define N 300010
#define P 998244353
#define G 3
#define Gi 332748118
#define inf 992337203685477580ll
typedef long long ll;
template<typename T> inline void read(T &x) {
    x = 0; char c = getchar(); bool flag = false;
    while (!isdigit(c)) {if (c == '-') flag = true; c = getchar(); }
    while (isdigit(c)) {x = (x << 1) + (x << 3) + (c ^ 48); c = getchar(); }
    if (flag)   x = -x;
}
using namespace std;
ll n, m, limi = 1, l; 
ll x[N], y[N], r[N];
ll ans, sum, toans = inf;
inline ll quickpow(ll x, ll k) {
    ll res = 1;
    while (k) {
        if (k & 1)  res = res * x % P;
        x = x * x % P;
        k >>= 1;
    }
    return res;
}
inline void ntt(ll *a, int type) {
    for (register int i = 0; i <= limi; ++i)
        if (i < r[i])   swap(a[i], a[r[i]]);
    for (register int i = 1; i < limi; i <<= 1) {
        ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));
        for (register int j = 0; j < limi; j += (i << 1)) {
            ll t = 1;
            for (register int p = 0; p < i; ++p, t = t * T % P) {
                ll nx = a[j + p], ny = t * a[j + p + i] % P;
                a[j + p] = (nx + ny) % P;
                a[j + p + i] = (nx - ny + P) % P;
            }
        }
    }
    if (type == -1) {
        ll inv = quickpow(limi, P - 2);
        for (register int i = 0; i <= limi; ++i)
            a[i] = a[i] * inv % P;
    }
}
int main() {
    read(n); read(m);
    for (register int i = 1; i <= n; ++i) read(x[i]), x[i + n] = x[i];
    for (register int i = 1; i <= n; ++i)   read(y[i]);
    for (register int i = 1; i <= n; ++i) {
        ans += x[i] * x[i] + y[i] * y[i];
        sum += x[i] - y[i];
    }
    sum *= 2;
    for (register int i = -m; i <= m; ++i) {
        toans = min(toans, 1ll * n * i * i + sum * i);
    }
    ans += toans;

    reverse(y + 1, y + n + 1);
    while (limi <= 2 * n)   limi <<= 1, l++;
    for (register int i = 0; i < limi; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    ntt(x, 1); ntt(y, 1);
    for (register int i = 0; i < limi; ++i) x[i] = x[i] * y[i] % P;
    ntt(x, -1);
    sum = 0;
    for (register int i = n + 1; i <= (n << 1); ++i)    sum = max(sum, x[i]);
    ans -= 2 * sum;
    printf("%lld\n", ans);
    return 0;
}
  1. 记得取模!+1

  2. 左移和右移一定分清!!

  3. 关于i = 0还是i = 1:

FFT和NTT里都是i = 0,别写成i = 1。

  1. 关于<= limi还是< limi:

写<= limi总不会错的。

统计答案的时候不要写<= limi!!!

第一层循环也不要写 <= limi,写 < limi

  1. 到了后面(多项式乘法时)n和m的出现次数就少了,主要是limi。

  2. cosnt int Gi = (M + 1) / G;以后就这么写吧,省着把332748118 写成 322748118

  3. NTT和FFT的第三层循环中的p应写成(int p = 0; p < i; ++p, t = t × T % P)。 +1

  4. 记住,是ax = a[j + p], ay = t × a[i + j + p]!!!别忘了乘t!!

  5. NTT和FFT的第一层循环应写成(int i = 1; i < limi; i <<= 1)。

  6. FFT中T为Complex(cos(PI / i), sin(PI / i) * type),横坐标是cos,纵坐标是sin!!

  7. 一开始蝴蝶变换的时候是swap(a[i], a[r[i]]),不是swap(i, r[i])!! +1

  8. ntt/fft 最终的除法操作是 type == -1 的时候做的!!!不是 type == 1!!!(真想不到还能这样出错)

习题

实际上这道题应该是例题的基础,是纯的FFT。

NTT配合manacher来做。细节不少,有一定难度。