NTT 学习笔记
Pre Knowledge
阶
方程 $a^x \equiv 1 \pmod m$(其中 $\gcd(a, m) = 1$)的最小正整数解 $x$ 称为 $a$ 模 $m$ 的阶,记作 $\delta_m(a)$。
原根
对正整数 $m$ 和与 $m$ 互质的整数 $g$,若 $g$ 模 $m$ 的阶恰好等于 $\varphi(m)$,则称 $g$ 为模 $m$ 的原根。
FFT to NTT
由于 FFT 依赖浮点复数运算,存在精度误差,也无法直接得到取模意义下的结果,因此考虑在另一些域里找到单位根的平替。
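考虑一个模质数 $p$ 的剩余系:设 $g$ 为 $p$ 的原根,当 $n \mid (p - 1)$ 时,$g^{(p-1)/n}$ 的阶恰为 $n$,性质与 $n$ 次单位根 $\omega_n$ 相同,可以用它代替 $\omega_n$(逆变换时改用 $g$ 的逆元)。常用 $p = 998244353 = 119 \times 2^{23} + 1$,$g = 3$。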
NTT 过程
几乎和 FFT 相同,直接看代码吧。
#include <bits/stdc++.h>
using namespace std;
// Array capacity, the prime modulus, the base used for forward twiddles (g)
// and the base used for inverse twiddles (gi, which satisfies 3 * gi % mod == 1).
const int N = 2.2e6 + 5;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

// n, m     : highest coefficient indices read into a[] and b[] by main
// len      : transform length, doubled by main until it exceeds n + m
// o        : number of doublings applied to len
// p[]      : index permutation table filled by main and applied by ntt
// a[], b[] : coefficient arrays, transformed in place
int n, m, len = 1, o;
int p[N], a[N], b[N];

/// Modular exponentiation by repeated squaring.
/// @param a base
/// @param b exponent; the loop consumes it bit by bit from the low end
/// @return a raised to the power b, reduced modulo `mod`
int ksm(int a, int b)
{
    int result = 1;
    for (; b != 0; b >>= 1)
    {
        // Multiply the current square into the result when this bit is set.
        if (b & 1)
        {
            result = 1ll * result * a % mod;
        }
        a = 1ll * a * a % mod;
    }
    return result;
}
void ntt(int x[], int tp)
{
for (int i = 0; i < len; i++) if (i < p[i]) swap(x[i], x[p[i]]);
for (int mid = 1; mid < len; mid <<= 1)
{
int w = ksm(tp == 1 ? g : gi, (mod - 1) / (mid << 1));
for (int R = mid << 1, j = 0; j < len; j += R)
{
int W = 1;
for (int k = 0; k < mid; k++, W = 1ll * W * w % mod)
{
int c1 = x[j+k], c2 = 1ll * W * x[j+k+mid] % mod;
x[j+k] = (c1 + c2) % mod; x[j+k+mid] = (c1 - c2 + mod) % mod;
}
}
}
}
/// Reads two polynomials (highest indices n and m, then n + 1 and m + 1
/// coefficients), multiplies them modulo `mod` with the NTT and prints the
/// n + m + 1 coefficients of the product separated by spaces.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i];
    for (int j = 0; j <= m; j++) cin >> b[j];

    // Smallest power of two strictly greater than n + m; o counts the doublings.
    while (len <= n + m) len <<= 1, o++;

    // p[0] is already 0 (zero-initialised global). Starting at 1 means the loop
    // body never runs when len == 1, so the shift count o - 1 is never negative.
    for (int i = 1; i < len; i++) p[i] = (p[i>>1] >> 1) | ((i & 1) << (o - 1));

    ntt(a, 1); ntt(b, 1);
    // Pointwise product over exactly the len transformed values.
    for (int i = 0; i < len; i++) a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a, -1);

    // ntt does not scale the inverse transform; compute 1/len once, not per coefficient.
    const int inv = ksm(len, mod - 2);
    for (int i = 0; i <= n + m; i++) cout << 1ll * a[i] * inv % mod << " ";
    return 0;
}
多项式乘法逆
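问题:给定多项式 $f(x)$,求多项式 $g(x)$,满足 $f(x)g(x) \equiv 1 \pmod{x^n}$,系数对 $998244353$ 取模。
考虑迭代,假设我们已经求出了 $g_0(x)$,满足 $f(x)g_0(x) \equiv 1 \pmod{x^{\lceil n/2 \rceil}}$(边界:只有常数项时 $g_0 = f_0^{-1}$)。
推导一下:由 $g(x) - g_0(x) \equiv 0 \pmod{x^{\lceil n/2 \rceil}}$,两边平方得 $g^2(x) - 2g(x)g_0(x) + g_0^2(x) \equiv 0 \pmod{x^n}$;两边同乘 $f(x)$ 并利用 $f(x)g(x) \equiv 1 \pmod{x^n}$,得 $g(x) - 2g_0(x) + f(x)g_0^2(x) \equiv 0 \pmod{x^n}$。
于是 $g(x) \equiv 2g_0(x) - f(x)g_0^2(x) \pmod{x^n}$,每轮迭代精度翻倍。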
#include <bits/stdc++.h>
using namespace std;
// Array capacity, the prime modulus, the base used for forward twiddles (g)
// and the base used for inverse twiddles (gi, which satisfies 3 * gi % mod == 1).
const int N = 2e6 + 5;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

// n   : first the number of input coefficients, later the transform length
//       read by ntt and mul (main reassigns it)
// nn  : twice the highest input index, set once by main
// m   : not referenced in the visible code
// p[] : index permutation table, rebuilt by main for every transform length
// a[] : input polynomial coefficients
int n, nn, m;
int p[N], a[N];

/// Modular exponentiation by repeated squaring.
/// @param a base
/// @param b exponent; processed one bit at a time, lowest bit first
/// @return a raised to the power b, reduced modulo `mod`
int ksm(int a, int b)
{
    int acc = 1;
    while (b != 0)
    {
        const bool bit_set = (b & 1) != 0;
        if (bit_set) acc = 1ll * acc * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return acc;
}
/// In-place iterative number-theoretic transform of the first `n` entries of x,
/// where `n` is the file-level transform length.
/// @param x  coefficient array; entries [0, n) are overwritten
/// @param tp 1 selects g as the twiddle base, any other value selects gi;
///           when tp == -1 the result is additionally multiplied by n^(mod-2)
void ntt(int x[], int tp)
{
    // Apply the permutation stored in p, swapping each pair exactly once.
    for (int i = 0; i < n; ++i)
        if (p[i] > i) swap(x[i], x[p[i]]);

    const int base = (tp == 1 ? g : gi);
    for (int half = 1; half < n; half <<= 1)
    {
        const int width = half << 1;
        // Per-step twiddle multiplier for butterflies of this width.
        const int wn = ksm(base, (mod - 1) / width);
        for (int start = 0; start < n; start += width)
        {
            int *lo = x + start;
            int *hi = lo + half;
            int cur = 1;
            for (int k = 0; k < half; ++k)
            {
                const int even = lo[k];
                const int odd = 1ll * cur * hi[k] % mod;
                lo[k] = (even + odd) % mod;
                hi[k] = (even - odd + mod) % mod;
                cur = 1ll * cur * wn % mod;
            }
        }
    }

    // Only the tp == -1 pass is scaled by the modular inverse of n.
    if (tp != -1) return;
    const int inv = ksm(n, mod - 2);
    for (int i = 0; i < n; ++i) x[i] = 1ll * x[i] * inv % mod;
}
// Scratch buffers for mul; only entries [0, n) are used per call.
int X[N], Y[N];

/// Multiplies the first n/2 coefficients of x by the first n/2 coefficients
/// of y modulo `mod`, using length-n transforms, and stores the n resulting
/// coefficients back into x. `n` is the file-level transform length and p
/// must already hold the permutation for that length.
/// @param x left operand and destination; entries [0, n) are overwritten
/// @param y right operand; may be the same array as x (inputs are copied first)
void mul(int x[], int y[])
{
    const int half = n >> 1;
    // ntt, the pointwise product and the copy-out below touch only [0, n), so
    // clearing the upper half of that range is sufficient. Clearing 2n entries
    // was wasted work and overran X/Y once 2n exceeded N.
    for (int i = 0; i < half; i++) X[i] = x[i], Y[i] = y[i];
    for (int i = half; i < n; i++) X[i] = Y[i] = 0;
    ntt(X, 1); ntt(Y, 1);
    for (int i = 0; i < n; i++) X[i] = 1ll * X[i] * Y[i] % mod;
    ntt(X, -1);
    for (int i = 0; i < n; i++) x[i] = X[i];
}
int b[20][N]; // b[k] holds g(x) after the k-th doubling round

/// Reads a coefficient count and that many coefficients into a[], builds the
/// inverse by repeatedly applying g <- 2g - g^2 * a, and prints one output
/// coefficient per input coefficient, separated by spaces.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    --n;            // n is now the highest coefficient index
    nn = n << 1;
    for (int idx = 0; idx <= n; ++idx) cin >> a[idx];

    // Round 0: the inverse of the constant term.
    b[0][0] = ksm(a[0], mod - 2);

    // From here on n is the transform length used by ntt and mul.
    n = 4;
    int i = 0;
    int len = 1, o = 1;
    while (len < nn)
    {
        ++i;
        int *prev = b[i - 1];
        int *cur = b[i];

        // Rebuild the permutation table for the current transform length.
        for (int j = 0; j < n; ++j) p[j] = (p[j >> 1] >> 1) | ((j & 1) << o);

        // cur = 2 * prev, taken before mul overwrites prev.
        for (int j = 0; j <= len; ++j) cur[j] = (prev[j] << 1) % mod;

        // prev = prev^2, then prev = prev^2 * a.
        mul(prev, prev);
        mul(prev, a);

        for (int j = 0; j <= len; ++j) cur[j] = (cur[j] - prev[j] + mod) % mod;

        ++o;
        len <<= 1;
        n <<= 1;
    }

    const int last = nn >> 1;
    for (int j = 0; j <= last; ++j) cout << b[i][j] << " ";
    return 0;
}