FFT & NTT总结(个人笔记)

柒葉灬

2019-01-03 16:06:28

Personal

# FFT & NTT总结 --------- $FFT$ 模板(**短但有精度问题的**): ```cpp int limit,l,r[maxm]; struct Complex{ double x,y; Complex(double xx=0,double yy=0){ x=xx;y=yy; } Complex operator +(const Complex &b)const{ return Complex(x+b.x,y+b.y); } Complex operator -(const Complex &b)const{ return Complex(x-b.x,y-b.y); } Complex operator *(const Complex &b)const{ return Complex(x*b.x-y*b.y,x*b.y+y*b.x); } }a[maxm],b[maxm]; void FFT(Complex *A,int type){ for(int i=0;i<limit;i++) if(i<r[i])swap(A[i],A[r[i]]); for(int i=1;i<limit;i<<=1){ Complex T(cos(Pi/i),type*sin(Pi/i)); for(int j=0;j<limit;j+=i<<1){ Complex t(1,0); for(int k=0;k<i;k++,t=t*T){ Complex x=A[j+k],y=t*A[j+k+i]; A[j+k]=x+y; A[j+k+i]=x-y; } } } } void calc(long long *ans,long long *A,long long *B,int len){ limit=1,l=0; while(limit<=len<<1)limit<<=1,l++; for(int i=1;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for(int i=0;i<=len;i++){ a[i].x=A[i]; b[i].x=B[i]; } FFT(a,1); FFT(b,1); for(int i=0;i<=limit;i++) a[i]=a[i]*b[i]; FFT(a,-1); for(int i=0;i<=len<<1;i++) ans[i]=(long long)(a[i].x/limit+0.5); for(int i=0;i<=limit;i++) a[i].x=a[i].y=b[i].x=b[i].y=0; } ``` $NTT$ 模板: ```cpp void NTT(long long *A,int type){ for(int i=0;i<limit;i++) if(i<r[i])swap(A[i],A[r[i]]); for(int i=1;i<limit;i<<=1){ long long T=qpow(type==1?G:Gi,(P-1)/(i<<1)); for(int j=0;j<limit;j+=i<<1){ long long t=1; for(int k=0;k<i;k++,t=t*T%P){ long long x=A[j+k],y=t*A[j+k+i]%P; A[j+k]=(x+y)%P; A[j+k+i]=(x-y+P)%P; } } } } ``` 其中:$P$ 表示模数,$G$ 表示这个模数的原根,$Gi$ 表示原根的逆元。 ---------------- ## 当求多项式相乘的时候有限制条件时怎么办? 例子:把下列 $O(n^2)$ 代码用 $FFT$ 改成 $O(nlog_2n)$ ```cpp void calc(){ for(int i=0;i<len;i++) for(int j=i+1;j<=len;j++) C[i+j]+=A[i]*B[j]; int ans=0; for(int i=1;i<=len<<1;i++) ans+=C[i]; cout<<ans<<endl; } ``` 不难发现,上列式子 $C[i+j]+=A[i] \times B[j]$ 中,要求$i < j $ 但简单的 $FFT$ 并不能实现这个**限制**功能。 这时候我们需要强行把一个变成负数。 即: $C[i-j]+=A[i] \times B[-j]$ 因为$i-j<0$,所以在负数的地方加上$len$ 即: $C[i-j+len]+=A[i] \times B[len-j]$ $i-j<0$,所以$i-j+len<len$, 下标$[0,len-1]$的地方就是我们要的答案。 #### 得到最终答案:翻转B数组,C[0 -> len-1]就是合法的答案。 >拓展 : >若$i \leq j$ 则是$C[0 -> len]$ >若$i>j$ 则是$C[len+1 -> len*2]$ >若$i \geq j$ 则是$C[len -> len*2]$ -------- 上面的例子太简单了,换一个难的。 ```cpp void calc(){ for(int i=0;i<len;i++) for(int j=i+1;j<=len;j++) C[i+j]+=A[i]*B[j]; for(int i=0;i<=len<<1;i++) cout<<C[i]<<" "; } ``` ~~上面的,不会......~~ _update in 2019/7/30 :_ #### 可以用类似分治的操作左半区间的$A_i$乘上右半区间的$B_i$, #### 复杂度$O(nlog^2n)$ ------- 任意模数: ```cpp void calc(int n,int m){ limit=1,l=0; while(limit<=n+m)limit<<=1,l++; for(int i=1;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for(int i=0;i<=n;i++){ A[i].x=num1[i]&16383; C[i].x=num1[i]>>14; } for(int i=0;i<=m;i++){ B[i].x=num2[i]&16383; D[i].x=num2[i]>>14; } FFT(A,1);FFT(B,1);FFT(C,1);FFT(D,1); for(int i=0;i<=limit;i++){ Complex a1=A[i],a2=C[i],b1=B[i],b2=D[i]; A[i]=a1*b1; B[i]=a1*b2; C[i]=a2*b1; D[i]=a2*b2; } FFT(A,-1);FFT(B,-1);FFT(C,-1);FFT(D,-1); for(int i=0;i<=n+m;i++){ long long x1=A[i].x+0.5,x2=B[i].x+0.5,x3=C[i].x+0.5,x4=D[i].x+0.5; x1%=P;x2%=P;x3%=P;x4%=P;//!!! res[i]=(x1%P+(x2<<14)%P+(x3<<14)%P+(x4<<28)%P)%P; } for(int i=0;i<=limit;i++){ A[i].x=A[i].y=B[i].x=B[i].y=C[i].x=C[i].y=D[i].x=D[i].y=0; } } ``` >上面的代码调用了$8$次FFT,太慢了。 ### 下面是重点要背的 ### 其中不仅是多模数的处理 ### 还有是精度处理良好的FFT #### 优化: ```cpp void FFT(Complex *A){ for(int i=1;i<limit;i++) if(i<r[i])swap(A[i],A[r[i]]); t[0].x=1; for(int i=1;i<limit;i<<=1){ Complex T=Complex{cos(Pi/i),sin(Pi/i)}; for(int j=i-2;j>=0;j-=2){ t[j]=t[j>>1]; t[j+1]=T*t[j]; } for(int j=0;j<limit;j+=i<<1){ for(int k=0;k<i;k++){ Complex x=A[j+k],y=t[k]*A[j+k+i]; A[j+k]=x+y; A[j+k+i]=x-y; } } } } void calc(int len){ limit=1,l=0; while(limit<=len<<1)limit<<=1,l++; for(int i=1;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for(int i=0;i<=len;i++){ A[i].x=num1[i]&32767; A[i].y=num1[i]>>15; B[i].x=num2[i]&32767; B[i].y=num2[i]>>15; } for(int i=len+1;i<limit;i++) A[i].clear(),B[i].clear(); FFT(A);FFT(B); for(int i=0;i<limit;i++){ int j=(limit-1)&(limit-i); C[j]=(Complex){0.5*(A[i].x+A[j].x),0.5*(A[i].y-A[j].y)}*B[i]; D[j]=(Complex){0.5*(A[i].y+A[j].y),0.5*(A[j].x-A[i].x)}*B[i]; } FFT(C);FFT(D); for(int i=len;i<=len<<1;i++){ long long x1=C[i].x/limit+0.5; long long x2=C[i].y/limit+0.5; long long x3=D[i].x/limit+0.5; long long x4=D[i].y/limit+0.5; x1%=P;x2%=P;x3%=P;x4%=P; res[i]=(x1+(x2<<15)+(x3<<15)+(x4<<30))%P; } } ``` -----