NTT 学习笔记

· · 个人记录

Pre Knowledge

对于满足 $\gcd(a, m) = 1$ 的整数 $a$,方程 $a^x \equiv 1 \pmod m$ 的最小正整数解 $x$ 称为 $a$ 在模 $m$ 意义下的阶,记作 $\text{ord}_m(a)$。

原根

对正整数 $m$,若存在正整数 $g$,使得 $\text{ord}_m(g) = \varphi(m)$,则称 $g$ 是模 $m$ 意义下的原根,或者简称 $g$ 是 $m$ 的原根。

FFT to NTT

由于 FFT 的复数运算在某些时候并不好用,考虑在另一些域里找到平替。

考虑一个模质数 $p$ 的域,此时若 $n \mid (p-1)$,就存在本原 $n$ 次单位根(即 FFT 中的 $\omega_n$)。由于 $n=2^k$,我们需要一个形如 $c \cdot 2^t+1$(其中 $2^t \ge n$)的质数 $p$(比如最经典的 $998244353=7\times17\times2^{23}+1$)。记 $q = \frac{p-1}{n}$,则 $p$ 的原根 $g$ 满足 $g^{qn} = g^{p-1} \equiv 1\pmod p$,取 $g_n = g^q$ 即可看做 FFT 中的 $\omega_n$。它们有一些相似的性质,比如 $g_n^n \equiv 1 \pmod p$,$g_n^{\frac{n}{2}} \equiv -1 \pmod p$。

NTT 过程

几乎和 FFT 相同,直接看代码吧。

#include <bits/stdc++.h>
using namespace std;
// N: capacity of every working array below.
// mod: the modulus 998244353 (= 7 * 17 * 2^23 + 1) used for all arithmetic.
// g: base handed to ksm for the forward transform; gi: base for the inverse
//    transform (3 * 332748118 == mod + 1, i.e. gi is the inverse of g modulo mod).
const int N = 2.2e6 + 5, mod = 998244353, g = 3, gi = 332748118;
// n, m: highest indices read into a and b; len: transform length (starts at 1,
// doubled in main); o: how many times len was doubled; p: bit-reversal table
// filled in main; a, b: the two coefficient arrays.
int n, m, len = 1, o, p[N], a[N], b[N];
// Fast modular exponentiation: returns a raised to the power b, reduced
// modulo the file-level constant `mod`. The exponent is consumed one binary
// digit at a time (lowest first); an exponent of 0 yields 1.
int ksm(int a, int b)
{
    long long acc = 1;      // running product
    long long base = a;     // a^(2^k) for the bit currently being examined
    for (; b; b >>= 1)
    {
        if (b & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(acc);
}
void ntt(int x[], int tp)
{
    for (int i = 0; i < len; i++) if (i < p[i]) swap(x[i], x[p[i]]);
    for (int mid = 1; mid < len; mid <<= 1)
    {
        int w = ksm(tp == 1 ? g : gi, (mod - 1) / (mid << 1)); 
        for (int R = mid << 1, j = 0; j < len; j += R) 
        {
            int W = 1;
            for (int k = 0; k < mid; k++, W = 1ll * W * w % mod)
            {
                int c1 = x[j+k], c2 = 1ll * W * x[j+k+mid] % mod;
                x[j+k] = (c1 + c2) % mod; x[j+k+mid] = (c1 - c2 + mod) % mod;
            }
        }
    }
}
// Reads n and m, then a[0..n] and b[0..m]; multiplies the two polynomials
// modulo `mod` with the NTT and prints the n + m + 1 product coefficients,
// each followed by a space.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i];
    for (int j = 0; j <= m; j++) cin >> b[j];
    // Smallest power of two strictly greater than n + m, so the product
    // (n + m + 1 coefficients) fits in one transform; o counts the doublings.
    while (len <= n + m) len <<= 1, o++;
    // Bit-reversal table. p[0] is already 0 (zero-initialised global), so the
    // loop starts at 1; that also avoids evaluating a shift by -1 when
    // len == 1 (o == 0).
    for (int i = 1; i < len; i++) p[i] = (p[i>>1] >> 1) | ((i & 1) << (o - 1));
    ntt(a, 1); ntt(b, 1);
    // Point-wise product over exactly the len transformed values.
    for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a, -1);
    // ntt does not scale the inverse transform; multiply by len^(mod-2) once
    // computed outside the loop.
    const int inv_len = ksm(len, mod - 2);
    for (int i = 0; i <= n + m; i++) cout << 1ll * a[i] * inv_len % mod << " ";
    return 0;
}

多项式乘法逆

问题:给定多项式 $f(x)$,求 $g(x)$,满足 $f(x) \cdot g(x) \equiv 1 \pmod{x^n}$,系数对 $998244353$ 取模。

考虑迭代,假设我们已经求出了 $f(x)$ 模 $x^{\frac{n}{2}}$ 的逆 $g'(x)$,考虑 $f(x)$ 模 $x^n$ 的逆 $g(x)$。

推导一下:

$$
\begin{aligned}
f(x)g'(x) &\equiv 1 \pmod{x^{\frac{n}{2}}}\\
f(x)g(x) &\equiv 1 \pmod{x^{\frac{n}{2}}}\\
\therefore\ g'(x)-g(x) &\equiv 0 \pmod{x^{\frac{n}{2}}}\\
\therefore\ (g'(x)-g(x))^2 &\equiv 0 \pmod{x^n}\\
g'(x)^2 -2g'(x)g(x) + g(x)^2 &\equiv 0 \pmod{x^n}\\
f(x)g'(x)^2 -2g'(x) + g(x) &\equiv 0 \pmod{x^n} \quad (\text{两边同乘 } f(x)\text{,并利用 } f(x)g(x) \equiv 1 \pmod{x^n})\\
g(x) &\equiv 2g'(x)-f(x)g'(x)^2 \pmod{x^n}
\end{aligned}
$$

于是 $g'(x)$ 就可以迭代到 $g(x)$,使用 NTT 进行多项式乘法,复杂度 $T(n) = T(\frac{n}{2}) + O(n \log n) = O(n \log n)$。

#include <bits/stdc++.h>
using namespace std;
// N: capacity of the working arrays; mod: the modulus 998244353 used for all
// arithmetic; g / gi: bases handed to ksm for the forward / inverse transform
// (3 * 332748118 == mod + 1, i.e. gi is the inverse of g modulo mod).
const int N = 2e6 + 5, mod = 998244353, g = 3, gi = 332748118;
// n: first the input size, later reused by main as the transform length that
// ntt and mul read; nn: twice (input size - 1), the bound of main's doubling
// loop; m: not referenced by the code in this program; p: bit-reversal table;
// a: input coefficients.
int n, nn, m, p[N], a[N];
// Binary exponentiation: computes a^b reduced modulo the file-level `mod`.
// Each pass handles the lowest remaining bit of b and squares the base;
// b == 0 returns 1.
int ksm(int a, int b)
{
    long long power = a;    // current a^(2^k)
    long long answer = 1;   // product of the powers selected so far
    while (b != 0)
    {
        if (b & 1)
        {
            answer = answer * power % mod;
        }
        power = power * power % mod;
        b >>= 1;
    }
    return static_cast<int>(answer);
}
// In-place iterative number-theoretic transform of x[0..n), where n is the
// global transform length and p the matching bit-reversal table.
//   tp == 1  : forward transform (twiddles are powers of g)
//   tp == -1 : inverse transform (twiddles are powers of gi) followed by
//              multiplication of every entry by n^(mod-2), i.e. by 1/n
// Any other tp uses gi as well but skips the final scaling.
void ntt(int x[], int tp)
{
    // Bit-reversal permutation; i < r makes sure each pair is swapped once.
    for (int i = 0; i < n; ++i)
    {
        const int r = p[i];
        if (i < r) swap(x[i], x[r]);
    }

    const int root = (tp == 1) ? g : gi;
    for (int half = 1; half < n; half *= 2)
    {
        const int span = half * 2;
        const int step = ksm(root, (mod - 1) / span);
        for (int start = 0; start < n; start += span)
        {
            int tw = 1;
            for (int k = 0; k < half; ++k)
            {
                const int lo = start + k;
                const int hi = lo + half;
                const int u = x[lo];
                const int v = 1ll * tw * x[hi] % mod;
                x[lo] = (u + v) % mod;
                x[hi] = (u - v + mod) % mod;
                tw = 1ll * tw * step % mod;
            }
        }
    }

    if (tp != -1) return;
    const int inv_n = ksm(n, mod - 2);
    for (int i = 0; i < n; ++i) x[i] = 1ll * x[i] * inv_n % mod;
}
// Scratch buffers for mul. ntt only ever reads or writes indices [0, n).
int X[N], Y[N];
// Multiplies polynomial x by polynomial y modulo `mod` and stores the result
// in x[0..n), where n is the global transform length.
//   x - left operand and destination; only x[0..n/2) is read
//   y - right operand; only y[0..n/2) is read, and it is not written unless
//       it is the same array as x
// Each operand contributes at most n/2 coefficients, so the product has fewer
// than n coefficients and fits the length-n transform without wrap-around.
void mul(int x[], int y[])
{
    const int half = n >> 1;
    // Load the low halves and zero-pad up to the transform length. Clearing
    // beyond n is unnecessary because ntt never touches those entries, and
    // staying within [0, n) keeps every write inside the part of X/Y that the
    // transform itself needs.
    for (int i = 0; i < n; i++)
    {
        X[i] = i < half ? x[i] : 0;
        Y[i] = i < half ? y[i] : 0;
    }
    ntt(X, 1); ntt(Y, 1);
    for (int i = 0; i < n; i++) X[i] = 1ll * X[i] * Y[i] % mod;
    ntt(X, -1);
    for (int i = 0; i < n; i++) x[i] = X[i];
}
int b[20][N]; // g(x) after each iteration: row i holds the i-th approximation of the inverse
// Reads the number of coefficients and the coefficients a[0..n-1] of f, then
// builds the inverse with the iteration g = 2g' - f * g'^2 and prints as many
// coefficients as were read, each followed by a space.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    // After this line n is the highest index of a, and nn is twice that value.
    cin >> n; n--; nn = n << 1;
    for (int i = 0; i <= n; i++) cin >> a[i];
    // Row 0 is the inverse of the constant term (a[0]^(mod-2)). From here on n
    // is reused as the transform length that ntt and mul read; it starts at 4.
    b[0][0] = ksm(a[0], mod - 2); n = 4;
    int i = 0; // index of the most recently completed row of b
    // Every round doubles len and n together, so n == 4 * len inside the body;
    // o is the bit position used to build the bit-reversal table for length n.
    for (int len = 1, o = 1; len < nn; o++, len <<= 1, n <<= 1)
    {
        i++;
        // Rebuild the bit-reversal table for the current transform length.
        for (int j = 0; j < n; j++) p[j] = (p[j>>1] >> 1) | ((j & 1) << o);
        // b[i] = 2 * g' on indices 0..len; must happen before mul overwrites b[i-1].
        for (int j = 0; j <= len; j++) b[i][j] = (b[i-1][j] << 1) % mod;
        // b[i-1] becomes g'^2, then f * g'^2 (mul stores its result in its first argument).
        mul(b[i-1], b[i-1]); mul(b[i-1], a);
        // b[i] = 2g' - f * g'^2 on indices 0..len; "+ mod" keeps the value non-negative.
        for (int j = 0; j <= len; j++) b[i][j] = (b[i][j] - b[i-1][j] + mod) % mod;
    }
    // nn >> 1 is the highest index that was read into a.
    for (int j = 0; j <= (nn >> 1); j++) cout << b[i][j] << " ";
    return 0;
}