原根与 NTT

· · 算法·理论

如何得到 n 次单位根

在 FNTT 中,我们并不是用原根来代替单位根,而是利用原根的性质,得到 n 次单位根。

考虑 FTT 中我们需要的单位根具有什么性质:

  1. 我们需要知道单位根 \omega_n^0 以及 \omega_n^1
  2. 每一次用当前的单位根 \omega_n^i 乘上 \omega_n^1 即可得到 \omega_n^{i+1}
  3. 进行 n 次操作之后恰好回到 \omega_n^0
  4. 这些单位根互不相同,否则点值会重复,插值会错误。
  5. 推导结论时需要用到的一些结论
    1. \omega_n^i\times\omega_n^j=\omega_n^{(i+j)\mod n}
    2. \omega_n^{2i}=\omega_{\frac{n}{2}}^{i}
    3. \omega_n^i+\omega_n^{i+\frac{n}{2}}=0

xp 意义下的阶恰好为 n。根据阶的定义 x^{i},0\le i<n 互不相同,且 x^{n}\equiv 1。那么我们设 \omega_{n}^{i}=x^{i}\bmod p,容易验证,除 5.3 以外上述性质全部被满足。如何找到满足 xp 意义下阶为 n 的数呢?

由原根的性质,若模 p 意义下有原根 g,则 g^{k} 的阶数为 \frac{p-1}{(k,p-1)}。我们希望 g^{k} 的阶数恰好为 n,这样就能满足我们上述性质了。

也就是说,n=\frac{p-1}{(k,p-1)},容易知道,当 n 不是 p-1 的因子时无解;否则,取 k=\frac{p-1}{n},就能得到 n 了。

所以,我们只需要知道 p 的一个原根,就能得到阶数为任意一个 p-1 的因子的数 x,进而用 x 代替单位根。

性质 5.3 由我们计算方法可以得到。

例如 p=998244353 时,p-1=2^{23}\times 7\times 17,此时 n=2^{k},k\le 23 的所有长度的 n 都存在 n 次单位根,我们就可以解决长度不超过 2^{23}=8388608 的多项式乘法了。

常用原根

$1945555039024054273=27\times 2^{56}+1$, $g=5$, $4179340454199820289=29\times 2^{57}+1$, $g=3$. 后两个模数是使用 FFT 常数过大,没有模数,且结果不会爆 `long long` 时,对这两个模数取模做 NTT 可以加速卷积,但是中间结果可能会爆 `long long`,需要开 `__int128`。 ## 模板代码 ```c++ const int N = 1 << 21; const int mod = 998244353, g = 3; using LL = long long; auto qpow = [](LL a, LL b) { LL res = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) res = res * a % mod; return res; }; auto mul = [](LL *a, LL *b, LL *c, int n) { static int tr[N]; for (int i = 0; i < n; ++i) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0); auto NTT = [](LL *a, int n, bool idft) { for (int i = 0; i < n; ++i) if (i < tr[i]) swap(a[i], a[tr[i]]); for (int len = 2; len <= n; len <<= 1) { int l = len >> 1; LL chg = qpow(g, (mod - 1) / len); if (idft) chg = qpow(chg, mod - 2); for (int k = 0; k < n; k += len) { LL rt = 1; for (int j = k; j < k + l; ++j) { LL tmp = a[j + l] * rt % mod; a[j + l] = a[j] - tmp + mod; if (a[j + l] >= mod) a[j + l] -= mod; a[j] = a[j] + tmp; if (a[j] >= mod) a[j] -= mod; (rt *= chg) %= mod; } } } if (idft) { LL inv_n = qpow(n, mod - 2); for (int i = 0; i < n; ++i) (a[i] *= inv_n) %= mod; } }; NTT(a, n, false); NTT(b, n, false); for (int i = 0; i < n; ++i) c[i] = a[i] * b[i] % mod; NTT(c, n, true); }; ``` 上面这份代码比较通俗易懂,但是参照小吴同学的代码,可以得到一份跑得更快,使用指针的代码: 感觉这份代码更通用,更短,orz WTC。 ```c++ #include <bits/stdc++.h> using namespace std; const int N = 1 << 21; const int mod = 998244353, RT = 3; using ll = long long; ll qpow(ll a, ll b) { ll res = 1; for (; b; b >>= 1, a = a * a % mod) if (b & 1) res = res * a % mod; return res; } int G[N], invG[N], rev[N]; void init(int n) { static int lst_n = 0; if (n == lst_n) return; lst_n = n; for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0); for (int i = 1; i < n; i <<= 1) { int g1 = qpow(RT, (mod - 1) / (i << 1)), ig1 = qpow(g1, mod - 2); int g = 1, ig = 1; for (int j = i; j < i + i; ++j) { G[j] = g, invG[j] = ig; g = 1ull * g * g1 % mod, ig = 1ull * ig * ig1 % mod; } } } void ntt(int *a, int n, bool inv) { for (int i = 0; i < n; ++i) if (rev[i] < i) swap(a[rev[i]], a[i]); for (int i = 1, x; i < n; i <<= 1) { for (int *j = a; j < a + n; j += (i << 1)) { for (int *k = j, *buf = (inv ? invG : G) + i; k < j + i; ++k, ++buf) { x = 1ull * k[i] * *buf % mod; if ((k[i] = *k + mod - x) >= mod) k[i] -= mod; if ((*k += x) >= mod) *k -= mod; } } } if (inv) { int invn = qpow(n, mod - 2); for (int i = 0; i < n; ++i) a[i] = 1ull * a[i] * invn % mod; } } void mul(int deg, int *a, int *b) { int n = 1; while (n <= deg) n <<= 1; init(n); ntt(a, n, false), ntt(b, n, false); for (int i = 0; i < n; ++i) a[i] = 1ull * a[i] * b[i] % mod; ntt(a, n, true); } int a[N], b[N]; int main() { cin.tie(0)->sync_with_stdio(false); int n, m; cin >> n >> m; for (int i = 0; i <= n; ++i) cin >> a[i]; for (int i = 0; i <= m; ++i) cin >> b[i]; mul(n + m, a, b); for (int i = 0; i <= n + m; ++i) cout << a[i] << ' '; cout << '\n'; } ```