FFT & NTT总结(个人笔记)
柒葉灬
2019-01-03 16:06:28
# FFT & NTT总结
---------
$FFT$ 模板(**短但有精度问题的**):
```cpp
int limit,l,r[maxm];
struct Complex{
double x,y;
Complex(double xx=0,double yy=0){
x=xx;y=yy;
}
Complex operator +(const Complex &b)const{
return Complex(x+b.x,y+b.y);
}
Complex operator -(const Complex &b)const{
return Complex(x-b.x,y-b.y);
}
Complex operator *(const Complex &b)const{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
}a[maxm],b[maxm];
void FFT(Complex *A,int type){
for(int i=0;i<limit;i++)
if(i<r[i])swap(A[i],A[r[i]]);
for(int i=1;i<limit;i<<=1){
Complex T(cos(Pi/i),type*sin(Pi/i));
for(int j=0;j<limit;j+=i<<1){
Complex t(1,0);
for(int k=0;k<i;k++,t=t*T){
Complex x=A[j+k],y=t*A[j+k+i];
A[j+k]=x+y;
A[j+k+i]=x-y;
}
}
}
}
void calc(long long *ans,long long *A,long long *B,int len){
limit=1,l=0;
while(limit<=len<<1)limit<<=1,l++;
for(int i=1;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<=len;i++){
a[i].x=A[i];
b[i].x=B[i];
}
FFT(a,1);
FFT(b,1);
for(int i=0;i<=limit;i++)
a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=len<<1;i++)
ans[i]=(long long)(a[i].x/limit+0.5);
for(int i=0;i<=limit;i++)
a[i].x=a[i].y=b[i].x=b[i].y=0;
}
```
$NTT$ 模板:
```cpp
void NTT(long long *A,int type){
for(int i=0;i<limit;i++)
if(i<r[i])swap(A[i],A[r[i]]);
for(int i=1;i<limit;i<<=1){
long long T=qpow(type==1?G:Gi,(P-1)/(i<<1));
for(int j=0;j<limit;j+=i<<1){
long long t=1;
for(int k=0;k<i;k++,t=t*T%P){
long long x=A[j+k],y=t*A[j+k+i]%P;
A[j+k]=(x+y)%P;
A[j+k+i]=(x-y+P)%P;
}
}
}
}
```
其中:$P$ 表示模数,$G$ 表示这个模数的原根,$Gi$ 表示原根的逆元。
----------------
## 当求多项式相乘的时候有限制条件时怎么办?
例子:把下列 $O(n^2)$ 代码用 $FFT$ 改成 $O(nlog_2n)$
```cpp
void calc(){
for(int i=0;i<len;i++)
for(int j=i+1;j<=len;j++)
C[i+j]+=A[i]*B[j];
int ans=0;
for(int i=1;i<=len<<1;i++)
ans+=C[i];
cout<<ans<<endl;
}
```
不难发现,上列式子 $C[i+j]+=A[i] \times B[j]$ 中,要求$i < j $
但简单的 $FFT$ 并不能实现这个**限制**功能。
这时候我们需要强行把一个变成负数。
即: $C[i-j]+=A[i] \times B[-j]$
因为$i-j<0$,所以在负数的地方加上$len$
即: $C[i-j+len]+=A[i] \times B[len-j]$
$i-j<0$,所以$i-j+len<len$,
下标$[0,len-1]$的地方就是我们要的答案。
#### 得到最终答案:翻转B数组,C[0 -> len-1]就是合法的答案。
>拓展 :
>若$i \leq j$ 则是$C[0 -> len]$
>若$i>j$ 则是$C[len+1 -> len*2]$
>若$i \geq j$ 则是$C[len -> len*2]$
--------
上面的例子太简单了,换一个难的。
```cpp
void calc(){
for(int i=0;i<len;i++)
for(int j=i+1;j<=len;j++)
C[i+j]+=A[i]*B[j];
for(int i=0;i<=len<<1;i++)
cout<<C[i]<<" ";
}
```
~~上面的,不会......~~
_update in 2019/7/30 :_
#### 可以用类似分治的操作左半区间的$A_i$乘上右半区间的$B_i$,
#### 复杂度$O(nlog^2n)$
-------
任意模数:
```cpp
void calc(int n,int m){
limit=1,l=0;
while(limit<=n+m)limit<<=1,l++;
for(int i=1;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<=n;i++){
A[i].x=num1[i]&16383;
C[i].x=num1[i]>>14;
}
for(int i=0;i<=m;i++){
B[i].x=num2[i]&16383;
D[i].x=num2[i]>>14;
}
FFT(A,1);FFT(B,1);FFT(C,1);FFT(D,1);
for(int i=0;i<=limit;i++){
Complex a1=A[i],a2=C[i],b1=B[i],b2=D[i];
A[i]=a1*b1;
B[i]=a1*b2;
C[i]=a2*b1;
D[i]=a2*b2;
}
FFT(A,-1);FFT(B,-1);FFT(C,-1);FFT(D,-1);
for(int i=0;i<=n+m;i++){
long long x1=A[i].x+0.5,x2=B[i].x+0.5,x3=C[i].x+0.5,x4=D[i].x+0.5;
x1%=P;x2%=P;x3%=P;x4%=P;//!!!
res[i]=(x1%P+(x2<<14)%P+(x3<<14)%P+(x4<<28)%P)%P;
}
for(int i=0;i<=limit;i++){
A[i].x=A[i].y=B[i].x=B[i].y=C[i].x=C[i].y=D[i].x=D[i].y=0;
}
}
```
>上面的代码调用了$8$次FFT,太慢了。
### 下面是重点要背的
### 其中不仅是多模数的处理
### 还有是精度处理良好的FFT
#### 优化:
```cpp
void FFT(Complex *A){
for(int i=1;i<limit;i++)
if(i<r[i])swap(A[i],A[r[i]]);
t[0].x=1;
for(int i=1;i<limit;i<<=1){
Complex T=Complex{cos(Pi/i),sin(Pi/i)};
for(int j=i-2;j>=0;j-=2){
t[j]=t[j>>1];
t[j+1]=T*t[j];
}
for(int j=0;j<limit;j+=i<<1){
for(int k=0;k<i;k++){
Complex x=A[j+k],y=t[k]*A[j+k+i];
A[j+k]=x+y;
A[j+k+i]=x-y;
}
}
}
}
void calc(int len){
limit=1,l=0;
while(limit<=len<<1)limit<<=1,l++;
for(int i=1;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<=len;i++){
A[i].x=num1[i]&32767;
A[i].y=num1[i]>>15;
B[i].x=num2[i]&32767;
B[i].y=num2[i]>>15;
}
for(int i=len+1;i<limit;i++)
A[i].clear(),B[i].clear();
FFT(A);FFT(B);
for(int i=0;i<limit;i++){
int j=(limit-1)&(limit-i);
C[j]=(Complex){0.5*(A[i].x+A[j].x),0.5*(A[i].y-A[j].y)}*B[i];
D[j]=(Complex){0.5*(A[i].y+A[j].y),0.5*(A[j].x-A[i].x)}*B[i];
}
FFT(C);FFT(D);
for(int i=len;i<=len<<1;i++){
long long x1=C[i].x/limit+0.5;
long long x2=C[i].y/limit+0.5;
long long x3=D[i].x/limit+0.5;
long long x4=D[i].y/limit+0.5;
x1%=P;x2%=P;x3%=P;x4%=P;
res[i]=(x1+(x2<<15)+(x3<<15)+(x4<<30))%P;
}
}
```
-----