yhx-12243 的 NTT 究竟写了些什么(详细揭秘)

moongazer

2021-04-02 16:36:37

Personal

这是 yhx-12243 的 NTT ```cpp inline int & reduce(int &x) {return x += x >> 31 & mod;} inline int & neg(int &x) {return x = (!x - 1) & (mod - x);} u64 PowerMod(u64 a, int n, u64 c = 1) {for (; n; n >>= 1, a = a * a % mod) if (n & 1) c = c * a % mod; return c;} namespace poly_base { int l, n; u64 iv; vec w2; void init(int n = N, bool dont_calc_factorials = true) { int i, t; for (inv[1] = 1, i = 2; i < n; ++i) inv[i] = u64(mod - mod / i) * inv[mod % i] % mod; if (!dont_calc_factorials) for (*finv = *fact = i = 1; i < n; ++i) fact[i] = (u64)fact[i - 1] * i % mod, finv[i] = (u64)finv[i - 1] * inv[i] % mod; t = min(n > 1 ? lg2(n - 1) : 0, 21), *w2 = 1, w2[1 << t] = PowerMod(31, 1 << (21 - t)); for (i = t; i; --i) w2[1 << (i - 1)] = (u64)w2[1 << i] * w2[1 << i] % mod; for (i = 1; i < n; ++i) w2[i] = (u64)w2[i & (i - 1)] * w2[i & -i] % mod; } inline void NTT_init(int len) {n = 1 << (l = len), iv = mod - (mod - 1) / n;} void DIF(int *a) { int i, *j, *k, len = n >> 1, R, *o; for (i = 0; i < l; ++i, len >>= 1) for (j = a, o = w2; j != a + n; j += len << 1, ++o) for (k = j; k != j + len; ++k) R = (u64)*o * k[len] % mod, reduce(k[len] = *k - R), reduce(*k += R - mod); } void DIT(int *a) { int i, *j, *k, len = 1, R, *o; for (i = 0; i < l; ++i, len <<= 1) for (j = a, o = w2; j != a + n; j += len << 1, ++o) for (k = j; k != j + len; ++k) reduce(R = *k + k[len] - mod), k[len] = u64(*k - k[len] + mod) * *o % mod, *k = R; } inline void DNTT(int *a) {DIF(a);} inline void IDNTT(int *a) { DIT(a), std::reverse(a + 1, a + n); for (int i = 0; i < n; ++i) a[i] = a[i] * iv % mod; } } ``` 它为什么跑这么快?DIT 和 DIF 在干啥?预处理的原根为何和大多数人的不一样?这篇文章将为你解开这一奥秘( 先来看 init 函数 `w2[1 << t] = PowerMod(31, 1 << (21 - t));` 为什么是 $31$? 我们发现 $31^{2^{23}}=1$ 同时它模 $998244353$ 的阶是 $2^{23}$ 的倍数,也就是说它在进行 NTT 时和 $3^{119}$ 具有相似的性质,事实上,这里的确可以换为 $3^{119}$。 平时我的写法都要预处理 $21$ 种原根的次幂,为什么这里只用处理一种原根呢?我们将 $31$ 改为 $3^{119}$ 输出一下这段代码预处理的原根前 $8$ 项,发现结果如下: ``` 1 911660635 372528824 488723995 929031873 373294451 628914303 661054123 ``` 再来看平常写法预处理的原根: ``` 1: 1 2: 1 911660635 4: 1 372528824 911660635 488723995 8: 1 929031873 372528824 628914303 911660635 373294451 488723995 661054123 ``` 我们发现对这一结果蝴蝶变换(二进制翻转)可以得到如下结果: ``` 1: 1 2: 1 911660635 4: 1 911660635 372528824 488723995 8: 1 911660635 372528824 488723995 929031873 373294451 628914303 661054123 ``` 我们发现 $1$ 是 $2$ 的前缀,$2$ 是 $4$ 的前缀…… 经过冷静思考,我们发现这是显然的,蝴蝶变换是 $0$ 不动,偶数放左边,奇数放右边,分别进行少一位的蝴蝶变换,而根据 $\omega_{2n}^{2i}=\omega_n^i$ 所以它前一半就是对 $\frac{n}{2}$ 范围的原根做蝴蝶变换的结果。 代码在做什么也很好懂了,预处理出 $g^{2^k}$ 放在 $2^{21-k}$ 处(即蝴蝶变换后的结果),再递推得到其他结果($g^{2^j+2^k}=g^{2^j}\times g^{2^k}$,二进制翻转后也可以这样找每个为 $1$ 的位乘上)。 这样预处理原根有什么用?等下就知道了。 我们还要知道它的基本原理:DIT/DIF。在 rushcheyo 学长《转置原理及其应用》中我们了解到 DIT(decimation in time,按时域抽取)-FFT 可以将蝴蝶变换后的系数向量转化为点值向量; DIF(decimation in frequency,按频域抽取)-FFT 可以将系数向量转化为蝴蝶变换后的点值向量,二者互为置换。 Update on 2023.01.13: 更新了一点内容,请在[这篇文章](https://blog.seniorious.cc/2023/relearn-FFT/)查看。 我们发现可以用 DIF 实现 DFT,用 DIT 实现 IDFT 于是我们就不用进行蝴蝶变换了。 这是我写的一份朴素的 DIT/DIF-NTT: ```cpp void init_Poly() { for (int l = 1; l < (1 << 21); l <<= 1) { gw[l] = 1; int gn = pow(g, (Mod - 1) / (l << 1), Mod); for (int j = 1; j < l; ++j) { gw[l | j] = 1ll * gw[l | (j - 1)] * gn % Mod; } } } void DIT(int *A, int lim, bool flag) { for (int l = 1; l < lim; l <<= 1) { int *k = A; for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) { int *x = k; for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) { int o = 1ll * x[l] * *g % Mod; x[l] = (*x + Mod - o) % Mod, *x = (*x + o) % Mod; } } } int iv = pow(lim, Mod - 2, Mod); for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * iv % Mod; std::reverse(A + 1, A + lim); } void DIF(int *A, int lim, bool flag) { for (int l = lim / 2; l >= 1; l >>= 1) { int *k = A; for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) { int *x = k; for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) { int o = x[l]; x[l] = 1ll * (*x + Mod - o) * *g % Mod, *x = (*x + o) % Mod; } } } } ``` 这里的原根是最朴素的处理方式,而在进行 DIT/DIF 的时候,我们需要移动 $\operatorname{O}(n\log n)$ 次原根,而 yhx-12243 的 DIT/DIF 只需要移动 $\operatorname{O}(n)$ 次。 我们还发现一件神奇的事:yhx-12243 的 DIT 除了最外层 $len$ 的枚举顺序,似乎都在做 DIF,而 DIF 除了最外层 $len$ 的枚举顺序,似乎都在做 DIT! 这是一张 DIT-FFT 和 DIF-FFT 的示意图: ![DIT-FFT 和 DIF-FFT 的示意图](https://cdn.luogu.com.cn/upload/image_hosting/bxyzprz3.png) 我们观察到 DIT-FFT 时如果对系数向量进行了蝴蝶变换,对 $(0,4)$ 操作变为了对 $(0,1)$ 操作,对 $(4,6)$ 操作变为了对 $(1,3)$ 操作,如果不对系数向量做蝴蝶变换并保持原先的操作呢(即仍然是对 $(0,4)$ 操作,对 $(4,6)$ 操作)?好像这样仍然会得到一个点值数组,这个点值数组正是蝴蝶变换后的点值数组! 原因是简单的:观察到蝴蝶变换的置换 $A$ 有:$A^{-1}=A$ 对于输入的系数数组做这一置换,运算过程不变,那么答案也应当也被做了该置换,于是 $A\circ A=I$(输入),$I\circ A=A$(答案)。 而原先要找的原根,也要对应的蝴蝶变换一下,这时候预处理蝴蝶变换后的原根的作用就体现出来了! 更为重要的是,对于一个 $len$ 覆盖到的范围,所用的原根次幂是相同的(例如第一层变换中的 $(0,4),(1,5),(2,6),(3,7)$,第二层变换中的 $(0,2),(1,3)$ 和 $(4,6),(5,7)$) 以上内容可以手画一下长为 $16$ 的 DIT-FFT 来加深理解。 于是按从大到小枚举 $len$ 的顺序做 DIT,干的就是 DIF 的事,同理我们也可以得到按从小到大枚举 $len$ 的顺序做 DIF,干的就是 DIT 的事,而这种做法因为只需要移动 $T(n+\frac{n}{2}+\frac{n}{4}+\cdots)=\operatorname{O}(n)$ 次原根所以会比原先快一些。 下面进行一些~~可能并不靠谱的~~效率差异比较(以下三份代码都使用 `unsigned long long` 优化,即用 ull 存储中间结果减少取模): 1. [朴素 FFT](https://duck.ac/submission/17229) 279.439 ms,代码 2.43 KB 2. [DIT-DIF FFT](https://duck.ac/submission/17228) 212.99 ms,代码 2.93 KB 3. [优化 DIT-DIF FFT](https://duck.ac/submission/17226) 192.85 ms,代码 2.94 KB 可见 DIT-DIF FFT 相较于朴素 FFT 相比,有较大优化,而优化 DIT-DIF FFT 相较于 DIT-DIF FFT 有小幅度优化,且代码不长,实现难度不大,不失为一种较好的简单 NTT 实现方式。