MTT学习笔记--题解 P4245 【【模板】任意模数NTT】
4次FFT的MTT做法。权当给自己看的了。
此题系数范围很大,如果用朴素的FFT会掉精度,从而我们需要将一个整数
然后就表示出
于是我们要求四个多项式的积:
考虑先求出
因为它们都是实数,不妨以虚数的形式将其两两结合。
有:
考虑借助第一个
先求
下面试图证明
然后就可以通过一次
再可以用普通线性方程求出
求
这里一共两次FFT。
然后得到
得到
代码如下。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
typedef long long ll;
typedef long double ldb;
const ldb pi = acos(-1);
const int MAXN = 400001;
const ll B = 1 << 15;
inline int read() {
int x = 0,f = 1; char ch = getchar();
while(ch > '9' || ch < '0') { if(ch == '-') f = -1; ch = getchar(); }
do x = x * 10 + ch - 48, ch = getchar(); while(ch >= '0' && ch <= '9');
return x * f;
}
struct Complex
{
ldb x,y;
Complex() { x = y = 0; }
Complex(const ldb _x,const ldb _y) : x(_x), y(_y) {}
Complex operator +(const Complex a) { return Complex(x + a.x,y + a.y); }
Complex operator -(const Complex a) { return Complex(x - a.x,y - a.y); }
Complex operator *(const Complex a) { return Complex(x * a.x - y * a.y,x * a.y + y * a.x); }
Complex operator *(const ldb t) { return Complex(x * t,y * t); }
Complex operator /(const ldb t) { return Complex(x / t,y / t); }
};
Complex I(0,1);
int n,m,p,r[MAXN]; Complex wn[MAXN];
Complex a0[MAXN],a1[MAXN],b0[MAXN],b1[MAXN];
Complex P[MAXN],Q[MAXN];
void FFT(Complex *a,int N) {
for(int i = 0;i < N;i++) if(i < r[i]) std::swap(a[i],a[r[i]]);
Complex t1,t2;
for(int n = 2, m = 1;n <= N;m = n, n <<= 1)
for(int l = 0;l < N;l += n)
for(int i = l;i < l + m;i++) {
Complex w = wn[N / m * (i - l)];
t1 = a[i], t2 = w * a[i + m];
a[i] = t1 + t2;
a[i + m] = t1 - t2;
}
return;
}
void IFFT(Complex *a,int N) {
FFT(a,N);
std::reverse(a + 1,a + N);
for(int i = 0;i < N;i++) a[i] = a[i] / N;
return;
}
void ConnectFFT(Complex *a,Complex *b,int N) {
for(int i = 0;i < N;i++) a[i].y = b[i].x;
FFT(a,N);
b[0] = Complex(a[0].x,-a[0].y);
for(int i = 1;i < N;i++) b[i] = Complex(a[N - i].x,-a[N - i].y);
Complex t1,t2;
for(int i = 0;i < N;i++) {
t1 = a[i], t2 = b[i];
a[i] = (t1 + t2) / 2.0;
b[i] = (t2 - t1) * I / 2.0;
}
return;
}
ll rd(const ldb x) {
if(x >= 0) return ll(x + 0.5) % p;
else return ll(x - 0.5) % p;
}
int main() {
n = read(), m = read(), p = read();
for(int i = 0,v;i <= n;i++) v = read() % p, a0[i].x = v / B, a1[i].x = v % B;
for(int i = 0,v;i <= m;i++) v = read() % p, b0[i].x = v / B, b1[i].x = v % B;
int N = 1, l = -1; while(N <= n + m) N <<= 1, l++;
for(int i = 1;i < N;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
for(int i = 0;i < N;i++) wn[i] = Complex(cos(pi / N * i),sin(pi / N * i));
ConnectFFT(a0,a1,N); ConnectFFT(b0,b1,N);
for(int i = 0;i < N;i++) {
P[i] = a0[i] * b0[i] + I * a1[i] * b0[i];
Q[i] = a0[i] * b1[i] + I * a1[i] * b1[i];
}
IFFT(P,N); IFFT(Q,N);
for(int i = 0;i <= n + m;i++) {
ll ans = B * B % p * rd(P[i].x) % p +
B * rd(Q[i].x + P[i].y) % p +
rd(Q[i].y) % p;
std::printf("%lld ",ans % p);
}
return 0;
}