FFT和NTT
快速傅里叶变换 (FFT) 和快速数论变换 (NTT)
-1.前言
离散傅里叶变换(Discrete Fourier Transform,缩写为 DFT),是傅里叶变换在时域和频域上都呈离散的形式。
FFT 是一种高效实现 DFT 的算法,称为 快速傅立叶变换(Fast Fourier Transform,FFT)。它对傅里叶变换的理论并没有新的发现,但是对于在计算机系统或者说数字系统中应用离散傅立叶变换,可以说是进了一大步。快速数论变换 (NTT) 是快速傅里叶变换(FFT)在数论基础上的实现。
这是 OI Wiki 中对 FFT 与 NTT 的概述。简单来说,FFT 与 NTT 可以快速计算多项式乘法,在信奥赛中属于多项式的基础。
由于本人太弱了,在看了许多博客后还是不太明白,所以通过这篇文章帮自己梳理一下。我自己写的肯定不如大佬们写的严谨,说不定会有很多错误,但也许会更好理解一些?希望也可以帮到你。
0.参考资料
这里是我看到的写的不错的文章和视频,讲解的很清楚,至少把我这个蒟蒻讲懂了……
-
傅里叶变换 (FFT) 学习笔记 —— command_block
-
NTT 与多项式全家桶 —— command_block
-
快速傅里叶变换 (FFT)——有史以来最巧妙的算法?
-
快速数论变换(NTT)超详解 —— 星夜
-
快速傅里叶变换 (FFT) 详解 —— 自为风月马前卒
-
快速数论变换 (NTT) 小结 —— 自为风月马前卒
1.多项式乘法
首先给出模板题
对于最简单的多项式乘法计算,大家肯定都知道,比如:
这是最简单的多项式乘法计算方法了。虽然简单,但它的劣势也很明显,太慢了。每次计算时间复杂度为
有没有什么更快的方法呢?
我们从另一个角度来看多项式乘法。
众所周知,多项式是可以画成函数图像形式的,比如:
有一条待定系数法的推论:
在平面直角坐标系中,
n+1 个点就能唯一确定一个n 次多项式。
那么,如果在这个二次函数图像上取三个点,也可以唯一表示这个多项式。这叫做多项式的点值表示。
比如:
-
w^k_n=(w^1_n)^k -
w^a_n \times w^b_n=w^{a+b}_n -
w^{2k}_{2n}=w^k_n -
w^{k+n/2}_n=-w^k_n
最后一个怎么理解呢?就相当于是转了半个圆周,关于原点中心对称,所以取了相反数。
有了这些性质,我们就可以把单位根愉快地代入之前的公式里面了!
现在,是时候用单位根来拯救我们的算法了! 先回忆一下公式吧:
把我们的单位根代进去:
化简一下?
所以最终式子为:
没太看懂?我们举个例子:
假如我们要求
即:
所以我们需要知道的就是
好耶,这样我们就能愉快的在
3.IDFT
前面已经完成了将多项式转换成点值表示的过程,那现在我们只需要反着来一次,将结果的点值表示重新转化为多项式,我们的 FFT 就大功告成了!
我在这里选择直接将结论告诉你,再证明结论的正确性。
我不进行正向的推导,是因为这里需要用到范德蒙德矩阵的相关知识,
而我不会而几乎所有我找到的博客中都没有对这部分进行展开。
注意!接下来涉及很多式子,一定要分清每一个多项式的含义
我们一开始的多项式为
我们如果将这个点值表示当作一个多项式的系数,即有一个多项式
我们再求出
如果你不太理解
w^{-k}_n ,我可以告诉你它是等于w^{n-k}_n 的,你可以试着把它画在图上看看。
那么我们考虑最暴力的那种
再把
慢慢化简:
后面那坨也太麻烦了,把它简化一下,我们设
我们现在先来看
当
当
两式相减:
显然,当
那么,再回到之前的大式子:
当
当
所以最终我们的式子就成了:
即:
看出来了吧?所以 IDFT 的过程其实就是又做了一遍 DFT,只不过这次代入的值为 (方便背诵 !为什么会这么巧?我不到啊
总之,我们现在理论部分都结束了,接下来就是代码实现了!
4.FFT 代码实现
为了保存一个多项式,我们直接开一个数组,数组第 i 位就表示多项式第 i 次项的系数即可(所以下标要从 0 开始!别忘了常数项!)。
- 还有一件事.mp3。我们既然要进行分治,要想保证每次都可以分下去,
n 必须是 2 的整次幂。所以不足的地方我们补一些 0 上去就好了。你可以理解为在后面加了一堆0x^k ,不影响结果。 - 还有一件事.mp3。单位根我们怎么求?利用三角函数(没学的点我),
w^1_n 即可表示为(\cos\frac{2\pi}{n})+(\sin\frac{2\pi}{n})i 。之后w^k_n 由w^1_n 乘k 次求出即可。 - 还有一件事.mp3。虽然
c++里有复数的 STL,但常数较大,万一被卡常不就寄了吗?所以还是自己写吧。
既然是分治算法,我们用递归来实现就好了。
这里的实现代码见大佬的傅里叶变换 (FFT) 学习笔记 。
因为没有优化的版本常数过大,无法通过模板题,所以我并没有写这里的代码 qwq。 (就是懒
我来介绍接下来的优化思路!
优化
递归的方式也太慢了,有没有方法可以减少数组的拷贝?
来观察一下经过这个分治,数组变成什么样了:
(借大佬图一用
观察一下原序列和反转后的序列
他们的下标,有没有什么性质?用瞪眼法
我们需要求的序列实际是原序列下标的二进制反转!
比如
所以我们可以预处理最后数列中每个数的位置,改变序列后倒着回去,就避免了大量的数组拷贝。
至于二进制反转,我们可以线性递推。
假如一个数
把这个二进制拆成两部分:
前面部分的反转你已经在之前求过了只要再判断一下二进制最后一位即可。
放代码理解一下:
for(int i=0;i<n;++i){
ver[i]=(ver[i>>1]>>1)|((i&1)?(n>>1):0);//n 这里已经保证是 2 的整次幂了
}
这样就完成了!比递归版的快了不少!
FFT 代码
放完整代码:
#include<bits/stdc++.h>
#define ll long long
#define pi acos(-1.0)//这样可以得到准确的 pi,大佬说的 qwq
using namespace std;
struct qwq{
double shi,xu;
}f[420005],g[420005];
int ver[420005];
int n,m,len;
qwq operator + (qwq x,qwq y){
qwq re;
re.shi=x.shi+y.shi;
re.xu=x.xu+y.xu;
return re;
}
qwq operator - (qwq x,qwq y){
qwq re;
re.shi=x.shi-y.shi;
re.xu=x.xu-y.xu;
return re;
}
qwq operator * (qwq x,qwq y){
qwq re;
re.shi=x.shi*y.shi-x.xu*y.xu;
re.xu=x.shi*y.xu+x.xu*y.shi;
return re;
}
inline int read(){
int w=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9'){
w=(w<<1)+(w<<3)+(ch^48);
ch=getchar();
}
return w;
}
inline void fft(qwq *a,double how){
for(int i=0;i<len;++i){
if(i<ver[i]) swap(a[i],a[ver[i]]);
}//利用二进制的对称交换位置
for(int k=2;k<=len;k<<=1){
qwq sum;
sum.shi=cos(2.0*pi/k);
sum.xu=sin(2.0*pi/k)*how;//处理一个 k 次单位根
for(int i=0;i<len;i+=k){
qwq xx={1,0};
for(int j=i;j<i+(k>>1);++j){
qwq num=xx*a[j+(k>>1)];
a[j+(k>>1)]=a[j]-num;
a[j]=a[j]+num;
xx=xx*sum;//得到下一个 k 次单位根
}
}
}
return;
}
int main()
{
n=read();
m=read();
for(int i=0;i<=n;++i) f[i].shi=read();
for(int i=0;i<=m;++i) g[i].shi=read();
n=n+m+1;
for(len=1;len<n;len<<=1);//保证 len 是 2 的整次幂
for(int i=0;i<len;++i){
ver[i]=(ver[i>>1]>>1)|(i&1?(len>>1):0);
}//预处理二进制反转
fft(f,1.0);
fft(g,1.0);
for(int i=0;i<len;++i) f[i]=f[i]*g[i];
fft(f,-1.0);
for(int i=0;i<n;++i){
printf("%d ",(int)(f[i].shi/len+0.49));//别忘了除 len!+0.49 是为了避免精度误差
}
return 0;
}
好耶!FFT 完结撒花!
还有一个优化,叫“三次变两次”,可以去大佬的傅里叶变换 (FFT) 学习笔记学一下 qwq。
(没错我懒
5.NTT
FFT 都够快了,为什么还要学 NTT 啊?
因为 FFT 用到了浮点数,不可避免的存在着精度丢失问题。而且,因为复数的使用,它的常数也很大。所以我们需要用 NTT 来减少精度误差,也可以使算法跑得更快一些。
NTT 和 FFT 的想法其实是一样的,或者说,式子都是一样的,只不过代入的数不一样。
欸?难道除了复数,还有符合我们要求的数吗?
乍一看确实没有,但是如果在模意义下呢?
又需要介绍一些新东西了。
原根
先放出定义:
设
m 是正整数,a 是整数,若a 模m 的阶等于\varphi(m) ,则称a 为模m 的一个原根
看不懂没关系,我们慢慢来介绍。
(如果你忘了欧拉定理,最好还是去看一眼吧)
先来说说阶。
你肯定知道,如果
所以,我们可以知道:
即,
很明显,它的最小循环节是 3,而不是 6。
比如
换句话说,阶是同余方程
那什么是原根呢?
如果
也就是说,
所以在上面的一个例子中,2 并不是 7 的一个原根。
回顾一下我们用到的单位根的性质:
- 对任意的 n,
w^0_n=1 w^k_n=(w^1_n)^k w^a_n \times w^b_n=w^{a+b}_n w^{2k}_{2n}=w^k_n w^{k+n/2}_n=-w^k_n
如果
在这里,如果把
为了方便,我们记
所以最后一条可以写成:
所以我们只需要证明:
下面开始证明: 根据欧拉定理,我们知道:
即
所以
而
至此,你会发现
还有一件事.mp3
为了实现递归过程,我们的
很巧的是,我们的老朋友
足够我们用了。它的最小原根是 3。
其实更巧的是,114514 也是 998244353 的一个原根……
所以,我们同样可以代入
所以变一下式子:
我们需要做的就是打出代码了!
6.NTT 代码实现
没啥好说的了,和 FFT 基本是一样的。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int p=998244353;
const int G=3;//原根
inline int read(){
int w=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9'){
w=(w<<1)+(w<<3)+(ch^48);
ch=getchar();
}
return w;
}
int n,m,len;
int ver[420005];
ll f[420005],g[420005];
inline ll poww(ll x,int y){
ll re=1,sum=x;
for(int i=y;i;i>>=1){
if(i&1) (re*=sum)%=p;
(sum*=sum)%=p;
}
return re;
}
inline void ntt(ll *a,int opt){
for(int i=0;i<len;++i){
if(i<ver[i]) swap(a[i],a[ver[i]]);
}
for(int k=2;k<=len;k<<=1){
ll sum=poww((opt==1)?G:poww(G,p-2),(p-1)/k);
for(int i=0;i<len;i+=k){
ll num=1;
for(int j=i;j<i+(k>>1);++j){
ll tt=num*a[j+(k>>1)]%p;
a[j+(k>>1)]=a[j]-tt;
if(a[j+(k>>1)]<0) a[j+(k>>1)]+=p;
a[j]+=tt;
if(a[j]>=p) a[j]-=p;
num=num*sum%p;
}
}
}
return;
}
int main(){
n=read();
m=read();
for(int i=0;i<=n;++i) f[i]=read();
for(int i=0;i<=m;++i) g[i]=read();
n=n+m+1;
for(len=1;len<n;len<<=1);
for(int i=0;i<len;++i){
ver[i]=(ver[i>>1]>>1)|((i&1)?(len>>1):0);
}
ntt(f,1);
ntt(g,1);
for(int i=0;i<len;++i){
f[i]=f[i]*g[i]%p;
}
ntt(f,-1);
ll invn=poww(len,p-2);
for(int i=0;i<n;++i){
printf("%lld ",f[i]*invn%p);
}
return 0;
}
代码基本是一模一样的。 好耶!NTT 完结撒花!
(终于写完了