MTT学习笔记--题解 P4245 【【模板】任意模数NTT】

· · 题解

4次FFT的MTT做法。权当给自己看的了。

此题系数范围很大,如果用朴素的FFT会掉精度,从而我们需要将一个整数 a 分解成 a0\times B+a1 的形式,其中 B 可取常数,代码中是 2^{15}。这样取到了精度,但是会增加复杂度。

然后就表示出 {a0\times B+a1}{b0\times B+b1} 的乘法为:

B^2a0b0+B(a0b1+a1b0)+a1b1

于是我们要求四个多项式的积:a0b0,a1b1,a1b0,a0b1

考虑先求出 a0,b0,a1,b1 的点值。

因为它们都是实数,不妨以虚数的形式将其两两结合。

有:

A_x=a0_x+ia1_x B_x=b0_x+ib1_x

考虑借助第一个 A 如何求出 a0,a1 的点值。

先求 A 的点值为 A1。然后将 A 的系数共轭后再构造 A2

下面试图证明 A1 的第 k 项与 A2 的第 n-k 项是共轭的。

A1_k&=\sum_{i=0}^{n-1}(a0_i+ia1_i)(w_n^k)^i \\ &=\sum_{i=0}^{n-1}a0_i(w_n^k)^i+ia1_i(w_n^k)^i \\ &=\sum_{i=0}^{n-1}a0_i\cos\theta+ia0_i\sin\theta+ia1_i\cos\theta+i^2a1_i\sin\theta \\ &=\sum_{i=0}^{n-1}(a0_i\cos\theta-a1_i\sin\theta)+i(a0_i\sin\theta+a1_i\cos\theta) \end{aligned} A2_{n-k}&=\sum_{i=0}^{n-1}(a0_i-ia1_i)(w_n^{-k})^i \\ &=\sum_{i=0}^{n-1}a0_i(w_n^{-k})^i-ia1_i(w_n^{-k})^i \\ &=\sum_{i=0}^{n-1}a0_i\cos\theta-ia0_i\sin\theta-ia1_i\cos\theta+i^2a1_i\sin\theta \\ &=\sum_{i=0}^{n-1}(a0_i\cos\theta-a1_i\sin\theta)-i(a0_i\sin\theta+a1_i\cos\theta) \end{aligned}

然后就可以通过一次 A 的DFT求出 A1A2

再可以用普通线性方程求出 a0,a1 的点值了。

b0,b1 类似。

这里一共两次FFT。

然后得到 a0,a1,b0,b1 的点值,现在求两两乘积的系数,同样考虑合并:

P(x)=a0_xb0_x+ia1_xb0_x Q(x)=a0_xb1_x+ia1_xb1_x

得到 PQ 的点值后两次FFT求出系数,取出各部代入篇头的式子里即可。

代码如下。

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>

typedef long long ll;
typedef long double ldb;
const ldb pi = acos(-1);
const int MAXN = 400001;
const ll B = 1 << 15;

inline int read() {
    int x = 0,f = 1; char ch = getchar();
    while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
    do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9');
    return x * f;
}

struct Complex
{
    ldb x,y;
    Complex() { x = y = 0; }
    Complex(const ldb _x,const ldb _y) : x(_x), y(_y) {}
    Complex operator +(const Complex a) { return Complex(x + a.x,y + a.y); }
    Complex operator -(const Complex a) { return Complex(x - a.x,y - a.y); }
    Complex operator *(const Complex a) { return Complex(x * a.x - y * a.y,x * a.y + y * a.x); }
    Complex operator *(const ldb t) { return Complex(x * t,y * t); }
    Complex operator /(const ldb t) { return Complex(x / t,y / t); }
};

Complex I(0,1);

int n,m,p,r[MAXN]; Complex wn[MAXN];
Complex a0[MAXN],a1[MAXN],b0[MAXN],b1[MAXN];
Complex P[MAXN],Q[MAXN];

void FFT(Complex *a,int N) {
    for(int i = 0;i < N;i++) if(i < r[i]) std::swap(a[i],a[r[i]]);
    Complex t1,t2;
    for(int n = 2, m = 1;n <= N;m = n, n <<= 1)
        for(int l = 0;l < N;l += n)
            for(int i = l;i < l + m;i++) {
                Complex w = wn[N / m * (i - l)];
                t1 = a[i], t2 = w * a[i + m];
                a[i] = t1 + t2;
                a[i + m] = t1 - t2;
            }
    return;
}

void IFFT(Complex *a,int N) {
    FFT(a,N);
    std::reverse(a + 1,a + N);
    for(int i = 0;i < N;i++) a[i] = a[i] / N;
    return;
}

void ConnectFFT(Complex *a,Complex *b,int N) {
    for(int i = 0;i < N;i++) a[i].y = b[i].x;
    FFT(a,N);
    b[0] = Complex(a[0].x,-a[0].y);
    for(int i = 1;i < N;i++) b[i] = Complex(a[N - i].x,-a[N - i].y);
    Complex t1,t2;
    for(int i = 0;i < N;i++) {
        t1 = a[i], t2 = b[i];
        a[i] = (t1 + t2) / 2.0;
        b[i] = (t2 - t1) * I / 2.0;
    }
    return;
}

ll rd(const ldb x) {
    if(x >= 0) return ll(x + 0.5) % p;
    else return ll(x - 0.5) % p;
}

int main() {
    n = read(), m = read(), p = read();
    for(int i = 0,v;i <= n;i++) v = read() % p, a0[i].x = v / B, a1[i].x = v % B;
    for(int i = 0,v;i <= m;i++) v = read() % p, b0[i].x = v / B, b1[i].x = v % B;
    int N = 1, l = -1; while(N <= n + m) N <<= 1, l++;
    for(int i = 1;i < N;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
    for(int i = 0;i < N;i++) wn[i] = Complex(cos(pi / N * i),sin(pi / N * i));
    ConnectFFT(a0,a1,N); ConnectFFT(b0,b1,N);
    for(int i = 0;i < N;i++) {
        P[i] = a0[i] * b0[i] + I * a1[i] * b0[i];
        Q[i] = a0[i] * b1[i] + I * a1[i] * b1[i];
    }
    IFFT(P,N); IFFT(Q,N);
    for(int i = 0;i <= n + m;i++) {
        ll ans = B * B % p * rd(P[i].x) % p +
                 B * rd(Q[i].x + P[i].y) % p +
                 rd(Q[i].y) % p;
        std::printf("%lld ",ans % p);
    }
    return 0;
}