MTT 初探

· · 个人记录

P4245 【模板】任意模数多项式乘法

这是一道板子题,但是我因为连续写了 2 次 3 模数 NTT 失败了,所以就换了一种写法。

\Large \text{3 模数 }NTT

既然是浅谈,那么也不扯那么多了。

首先我们考虑对两个式子进行 FFT,发现精度直接爆炸,这里直接飙到 10^{23}

因为我们最后乘出来的数最多到 len \times mod \times mod 所以直接爆炸

之后我们考虑 3 模数 NTT,也就是选 3 个模数,之后进行 CRT 合并。

这里建议看一下别人的博客,因为笔者真的菜,没写出来。

但是关于 CRT 和 EXCRT 的文章大家可以看看。

\Large \color{red}MTT

首先我们考虑将题目的要求表示出来

P = \sum_{i = 0} ^ nP_ix^i, Q = \sum_{i = 0} ^ m Q_i x^i

P \times Q 我们直接入正题

考虑将其拆成几部分

A = P_i >> 15, B = P_i \& (32767) C = Q_i >> 15, D = Q_i \& (32767) P*Q=AC*2^{30}+(AD+BC)*2^{15}+BD

之后我们考虑通过复数来表示,设 F = A + iB, G = C + iD

这样我们可以先通过 FFT 之后再将其表示成答案的形式,最后在计算。

但是我们发现,需要表示之前的式子需要 \overline{F}, \overline{G}

所以我们还需要 6 次 FFT,我们考虑优化这个过程。

\text{证明:P = A + iB, Q = A - iB} \color{red}\text{注意这里 P, Q 与之前的不同}

我们假设 P_t 表示 P 的 DFT

\begin{aligned} P_t(x) &= A(w_n^x) + iB(w_n^k) \\ & =\sum\limits_{j=0}^{n-1}A_j\omega _n^{jk}+iB_j\omega _n^{jk} \\ &=\sum\limits_{j=0}^{n-1}(A_J+iB_j)\omega _n^{jk}\\ \\ Q_t(x)=&A(\omega _n^k)-i B(\omega _n^k) \\ =&\sum\limits_{j=0}^{n-1}A_j\omega _n^{jk}-iB_j\omega _n^{jk} \\ =&\sum\limits_{j=0}^{n-1}(A_J-iB_j)\omega _n^{jk} \\ =&\sum\limits_{j=0}^{n-1}(A_j-i*B_j)(cos(\frac{2\pi jk}{n})+i*sin(\frac{2\pi jk}{n})) \\ =&\sum\limits_{j=0}^{n-1}(A_jcos(\frac{2\pi jk}{n})+B_jsin(\frac{2\pi jk}{n}))+i(A_jsin(\frac{2\pi jk}{n})-B_jcos(\frac{2\pi jk}{n})) \\ =&conj(\sum\limits_{j=0}^{n-1}(A_jcos(\frac{2\pi jk}{n})+B_jsin(\frac{2\pi jk}{n}))-i(A_jsin(\frac{2\pi jk}{n})-B_jcos(\frac{2\pi jk}{n}))) \\ =&conj(\sum\limits_{j=0}^{n-1}(A_jcos(\frac{-2\pi jk}{n})-B_jsin(\frac{-2\pi jk}{n}))-i(A_jsin(\frac{-2\pi jk}{n})+B_jcos(\frac{-2\pi jk}{n}))) \\ =&conj(\sum\limits_{j=0}^{n-1}(A_j+iB_j)(cos(\frac{-2\pi jk}{n})+i*sin(\frac{-2\pi jk}{n}))) \\ =&conj(\sum\limits_{j=0}^{n-1}(A_j+iB_j)\omega _n^{-jk}) \\ =&conj(\sum\limits_{j=0}^{n-1}(A_j+iB_j)\omega _n^{(n-j)k}) \\ =&conj(P_t(n-k)) \end{aligned}

从这里看着打的

个人认为这篇博客挺好,但是有些细节没讲到。

其中 conj(P) 也表示 P 的共轭复数

由此我们已经可以通过 F 快速求出 \overline{F}

那么我们考虑构造一下最终的答案吧!

\begin{cases} F = A + iB, G = C + iD \\ ans = P*Q=AC*2^{30}+(AD+BC)*2^{15}+BD \\ F_0 = B, F_1 = A \\ G_0 = D, G_1 = C \\ \end{cases}

所以我们就可以组合出 ans 了。 代码实现中是用 a 来存 ACb 的实部来存 AD + BCb 的虚部来存 BD

注意我们有些时候只要用到实部或虚部,我们可以通过一个复数和其共轭复数来进行消元。

\begin{cases} F = A + iB \\ \overline{F} = A - iB \\ F_0 = -i\times\frac{F- \overline{F}}{2} = B \\ F_1 = \frac{F + \overline{F}}{2} = A \\ \end{cases}

G 处理的方法是相同的。

那我们来看一下代码吧!

Code
#include <bits/stdc++.h>
using namespace std;
template <typename T>
void r1(T &x) {
    x = 0;
    char c(getchar());
    int f(1);
    for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
    for(; isdigit(c);c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    x *= f;
}
#define int long long
const int maxn = (1 << 18) + 5;
const int maxm = maxn << 1;

typedef int room[maxn];

int n, m;
struct Complex {
    double x, y;
    Complex(double tx = 0, double ty = 0) {x = tx, y = ty;}
    Complex operator + (const Complex &z) const {return Complex(x + z.x, y + z.y);}
    Complex operator - (const Complex &z) const {return Complex(x - z.x, y - z.y);}
    Complex operator * (const Complex &z) const {return Complex(x * z.x - y * z.y , z.x * y + x * z.y);}
    Complex operator * (const double &z) const {return Complex(x * z, y * z);}
    Complex operator ~ () const {return Complex(x, -y);} // 这里表示 取共轭复数
}f[maxn], g[maxn], w[maxn], a[maxn], b[maxn];
int p, rev[maxn];
int lim(1), len(0);
void fft(Complex *A) { // 正常 FFT
    for(int i = 0; i < lim; ++ i) if(i < rev[i]) swap(A[i], A[rev[i]]);
    for(int mid = 1; mid < lim; mid <<= 1) {
        for(int j = 0; j < lim; j += (mid << 1)) {
            for(int k = 0; k < mid; ++ k) {
                const Complex x = A[j + k], y = A[j + k + mid] * w[k + mid];
                A[j + k] = x + y;
                A[j + k + mid] = x - y;
            }
        }
    }
}
const double pi = acos(-1.0);
signed main() {
    int i, j;
    r1(n), r1(m), r1(p);
    for(i = 0; i <= n; ++ i) {
        int x;
        r1(x);
        x %= p;
        f[i].x = (x >> 15), f[i].y = (x & 32767);
    }
    for(i = 0; i <= m; ++ i) {
        int x;
        r1(x);
        x %= p;
        g[i].x = (x >> 15), g[i].y = (x & 32767);
    }
    while(lim <= n + m) lim <<= 1, ++ len;
    for(i = 0; i < lim; ++ i) rev[i] = rev[i >> 1] >> 1 | ((i & 1) <<(len - 1));
    for(i = 1; i < lim; i <<= 1) {
        w[i] = Complex(1, 0);
        for(j = 1; j < i; ++ j)
            if((j & 31) == 1) w[i + j] = Complex(cos(pi * j / i), sin(pi * j / i));
            else w[i + j] = w[i + j - 1] * w[i + 1];
    }
    fft(f), fft(g);
    for(i = 0; i < lim; ++ i) {
        Complex q, f0, g0, f1, g1;
        q = ~f[i ?lim - i : 0], f0 = (f[i] - q) * Complex(0, -0.5), f1 = (f[i] + q) * Complex(0.5, 0);
        q = ~g[i ?lim - i : 0], g0 = (g[i] - q) * Complex(0, -0.5), g1 = (g[i] + q) * Complex(0.5, 0);
        a[i] = f1 * g1, b[i] = g0 * f1 + f0 * g1 + (f0 * g0) * Complex(0, 1);
    } // 这里和之前的表示一样
    reverse(a + 1, a + lim), reverse(b + 1, b + lim);
    fft(a), fft(b);
    double k = 1.0 / lim;
    // 注意最后还要乘 1 / lim
    for(i = 0; i <= n + m; ++ i) printf("%lld ", (((int)(a[i].x * k+ 0.5) % p << 30) + ((int)(b[i].x * k+ 0.5) % p << 15) + ((int)(b[i].y * k+ 0.5) % p)) % p);
    return 0;
}