FFT优化技巧——MTT

SuperJvRuo

2018-08-05 18:35:43

Personal

## 一、MTT的适用范围 模板题:[P4245 【模板】任意模数NTT](https://www.luogu.org/problemnew/show/P4245) 现在算法竞赛的毒瘤们常常要求对某一个数取膜。如果使用FFT直接求解,```long double```也会出现不小的误差。如果取膜的数是一个比DFT长度大的、形如$a\times2^k+1$的数,那么我们可以求出原根,代替复根,进行DFT和IDFT。这种算法被称为数论变换(Number Theorem Transform(NTT))。 然而,这样的膜数毕竟很少,NTT的适用面较窄。如果遇到随便的一个膜数,一种方法是选择三个满足NTT性质的膜数,分别求出这三个膜数意义下的卷积,再使用CRT确定答案。 这种方法的效率并不让人满意。同时,对于初学DFT的选手们来说,原根、CRT这样的知识也难以迅速全部理解。其实,要对所得的多项式取膜并保证精度,还有一种更易懂,更易实现,效率比三膜数NTT更高的实现方法。这种方法就是MTT(毛啸变换(Matthew Transform))。取膜和不取膜的多项式卷积都可以使用MTT。 ## 二、MTT的基本思想 拆系数。 以洛谷P4245任意膜数NTT为例,发现系数的范围在$10^9$以内。可将两个多项式的系数都拆成$a\times32768+b$的形式。 比如计算多项式$A(x)$和$B(x)$在$mod\space p$意义下的卷积。将$A(x)$的系数序列$a$拆为两个序列$k[a]$和$b[a]$,其中$k[a_i]=a_i\div32768,b[a_i]=a_i\space mod\space32768$。设这两个新序列形成的多项式为$C,D$。对$B$也进行相同操作,形成多项式$E,F$。 不难发现,整个多项式在$n$处的点值表示即为 $$A(n)\cdot B(n)=(32768C(n)+D(n))\cdot(32768E(n)+F(n))$$ $$=1073741824C(n)\cdot E(n)+32768(C(n)\cdot F(n)+E(n)\cdot D(n))+D(n)\cdot F(n)$$ 为了保证精度,我们计算$C(n)\cdot E(n),C(n)\cdot F(n),E(n)\cdot D(n),D(n)\cdot F(n)$四个卷积。再按上式计算点值。 ## 三、MTT的实现细节 我们预先算出上面$C,D,E,F$多项式的DFT,再求出4个卷积的点值表示,再进行4次IDFT算出四个卷积,然后按贡献相加。这种最朴素的MTT共需要4次复数意义下的DFT和4次复数意义下的IDFT。 注意到$C(n)\cdot F(n),E(n)\cdot D(n)$的贡献相同,可以直接计算$C(n)\cdot F(n)+E(n)\cdot D(n)$的点值表示,再经过一次IDFT求出其系数。这样可以减少一次IDFT。 既然是复数计算,就要考虑精度。这里一定要使用$std::sin$和$std::cos$,一定要加$std::$。这样精度更高。 ```cpp #include<cstdio> #include<cmath> #include<cctype> #include<cstring> #include<algorithm> #define LL long long int Read() { int x=0;char c=getchar(); while(!isdigit(c)) { c=getchar(); } while(isdigit(c)) { x=x*10+(c^48); c=getchar(); } return x; } const long double PI=acos(-1); struct Complex { long double r,i; Complex(long double R=0,long double I=0) { r=R; i=I; } }ak[1000005],ab[1000005],bk[1000005],bb[1000005],A[1000005],B[1000005],C[1000005],D[1000005]; Complex operator * (Complex a,Complex b) { return Complex(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r); } Complex operator + (Complex a,Complex b) { return Complex(a.r+b.r,a.i+b.i); } Complex operator - (Complex a,Complex b) { return Complex(a.r-b.r,a.i-b.i); } int rev[1000005]; void getrev(int bit) { for(int i=0;i<(1<<bit);i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); } void fft(Complex* a,int n,int dft) { for(int i=0;i<n;i++) if(i<rev[i]) std::swap(a[i],a[rev[i]]); for(int step=1;step<n;step<<=1) { Complex omega(std::cos(dft*PI/step),std::sin(dft*PI/step)); for(int j=0;j<n;j+=step<<1) { Complex omega_k(1,0); for(int k=j;k<j+step;k++) { Complex x=a[k]; Complex y=omega_k*a[k+step]; a[k]=x+y; a[k+step]=x-y; omega_k=omega_k*omega; } } } if(dft==-1) { for(int i=0;i<n;i++) a[i].r/=n; } } int num_a[100005],num_b[100005]; int main() { int n=Read(),m=Read(),p=Read(); int bit=1,s=2; for(bit=1;(1<<bit)<n+m+1;++bit) s<<=1; for(int i=0;i<=n;++i) { num_a[i]=Read()%p; ak[i].r=num_a[i]>>15; ab[i].r=num_a[i]&0x7fff; } for(int i=0;i<=m;++i) { num_b[i]=Read()%p; bk[i].r=num_b[i]>>15; bb[i].r=num_b[i]&0x7fff; } getrev(bit); fft(ak,s,1); fft(ab,s,1); fft(bk,s,1); fft(bb,s,1); for(int i=0;i<s;++i) { A[i]=ak[i]*bk[i]; B[i]=ak[i]*bb[i]+ab[i]*bk[i]; D[i]=ab[i]*bb[i]; } fft(A,s,-1); fft(B,s,-1); fft(D,s,-1); for(int i=0;i<=m+n;++i) { printf("%lld ",((((LL)(A[i].r/1+0.5)%p)<<30)+(((LL)(B[i].r/1+0.5)%p)<<15)+(LL)(D[i].r/1+0.5)%p)%p); } return 0; } ``` ## 四、进一步的优化 事实上,MTT还可以继续优化,在毛啸的论文中,最多可以减至1.5次DFT和2次IDFT。 这里让我们来一步步优化。 首先介绍一种将DFT两两合并的技巧, 考虑对长度为$n=2^n$的实多项式$A(x),B(x)$进行DFT。 定义: $$P(x)=A(x)+iB(x)$$ $$Q(x)=A(x)-iB(x)$$ 其中$i=\sqrt{-1}$。 那么有: $$DFT(A)=\frac{DFT(P)+DFT(Q)}{2}$$ $$DFT(B)=i\frac{DFT(P)-DFT(Q)}{2}$$ 如果能在求出$DFT(P(x))$后迅速求出$DFT(Q(x))$就好了,这样两次DFT可以优化为一次。 我们可以进行这样的推导: ``` 即得易见平凡,仿照上例显然。留作习题答案略,读者自证不难。 反之亦然同理,推论自然成立。略去过程Q.E.D.,由上可知证毕。 ``` (以后会完善证明过程) 可知: $$Q(\omega^k_n)=conj(P(\omega^{n-k}_n))$$ 其中$conj(x)$表示$x$的共轭复数。这样,可以把两次实多项式的DFT合并成一次DFT。IDFT也可以作同样的优化。 这样,4次DFT可以优化成2次,3次IDFT中选2次合并成1次,总共2次IDFT。这样的效率已经相当高了。 这两次DFT还可以继续优化。 (鸽子:别急,我正在写呢) ## 五、神仙代码 [毛啸在UOJ多项式乘法一题中的提交记录](http://uoj.ac/submission/49836) ``` #include <bits/stdc++.h> using namespace std; #define REP(i, a, b) for (int i = (a), _end_ = (b); i < _end_; ++i) #define debug(...) fprintf(stderr, __VA_ARGS__) #define mp make_pair #define x first #define y second #define pb push_back #define SZ(x) (int((x).size())) #define ALL(x) (x).begin(), (x).end() template<typename T> inline bool chkmin(T &a, const T &b) { return a > b ? a = b, 1 : 0; } template<typename T> inline bool chkmax(T &a, const T &b) { return a < b ? a = b, 1 : 0; } typedef long long LL; const int oo = 0x3f3f3f3f; const int Mod = 1e9 + 7; const int max0 = 262144; struct comp { double x, y; comp(): x(0), y(0) { } comp(const double &_x, const double &_y): x(_x), y(_y) { } }; inline comp operator+(const comp &a, const comp &b) { return comp(a.x + b.x, a.y + b.y); } inline comp operator-(const comp &a, const comp &b) { return comp(a.x - b.x, a.y - b.y); } inline comp operator*(const comp &a, const comp &b) { return comp(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); } inline comp conj(const comp &a) { return comp(a.x, -a.y); } const double PI = acos(-1); int N, L; comp w[max0 + 5]; int bitrev[max0 + 5]; void fft(comp *a, const int &n) { REP(i, 0, n) if (i < bitrev[i]) swap(a[i], a[bitrev[i]]); for (int i = 2, lyc = n >> 1; i <= n; i <<= 1, lyc >>= 1) for (int j = 0; j < n; j += i) { comp *l = a + j, *r = a + j + (i >> 1), *p = w; REP(k, 0, i >> 1) { comp tmp = *r * *p; *r = *l - tmp, *l = *l + tmp; ++l, ++r, p += lyc; } } } inline void fft_prepare() { REP(i, 0, N) bitrev[i] = bitrev[i >> 1] >> 1 | ((i & 1) << (L - 1)); REP(i, 0, N) w[i] = comp(cos(2 * PI * i / N), sin(2 * PI * i / N)); } inline void conv(int *x, int *y, int *z) { REP(i, 0, N) (x[i] += Mod) %= Mod, (y[i] += Mod) %= Mod; static comp a[max0 + 5], b[max0 + 5]; static comp dfta[max0 + 5], dftb[max0 + 5], dftc[max0 + 5], dftd[max0 + 5]; REP(i, 0, N) a[i] = comp(x[i] & 32767, x[i] >> 15); REP(i, 0, N) b[i] = comp(y[i] & 32767, y[i] >> 15); fft(a, N), fft(b, N); REP(i, 0, N) { int j = (N - i) & (N - 1); static comp da, db, dc, dd; da = (a[i] + conj(a[j])) * comp(0.5, 0); db = (a[i] - conj(a[j])) * comp(0, -0.5); dc = (b[i] + conj(b[j])) * comp(0.5, 0); dd = (b[i] - conj(b[j])) * comp(0, -0.5); dfta[j] = da * dc; dftb[j] = da * dd; dftc[j] = db * dc; dftd[j] = db * dd; } REP(i, 0, N) a[i] = dfta[i] + dftb[i] * comp(0, 1); REP(i, 0, N) b[i] = dftc[i] + dftd[i] * comp(0, 1); fft(a, N), fft(b, N); REP(i, 0, N) { int da = (LL)(a[i].x / N + 0.5) % Mod; int db = (LL)(a[i].y / N + 0.5) % Mod; int dc = (LL)(b[i].x / N + 0.5) % Mod; int dd = (LL)(b[i].y / N + 0.5) % Mod; z[i] = (da + ((LL)(db + dc) << 15) + ((LL)dd << 30)) % Mod; } } int main() { #ifndef ONLINE_JUDGE freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); #endif int n, m; static int a[max0 + 5], b[max0 + 5], c[max0 + 5]; scanf("%d%d", &n, &m), ++n, ++m; REP(i, 0, n) scanf("%d", a + i); REP(i, 0, m) scanf("%d", b + i); L = 0; for ( ; (1 << L) < n + m - 1; ++L); N = 1 << L; fft_prepare(); conv(a, b, c); REP(i, 0, n + m - 1) (c[i] += Mod) %= Mod, printf("%d ", c[i]); printf("\n"); return 0; } ```