我也要学这种东西吗?
wangxx2012 · · 算法·理论
快速傅里叶变换 FFT
模板题
梦开始的地方。
点表示法
首先,我们要了解到多项式的性质:
- 用任意
n+1 个不同点,均可唯一确定一个n 次多项式A\left(x\right)=a_0+a_1x+a_2x^2+\cdots+a_nx^n 。
:::info[证明]
我们令这
注意到这是一个
那么它有唯一解,即等价于它的系数矩阵的秩是满秩的,即等价于它的系数矩阵的行列式是不等于
这里给出它的行列式:
这个行列式的值等于
即等于
因为
命题得证。 :::
由于这个性质,我们即可将一个多项式的系数表示法转换为点表示法。
而 FFT 是取
单位根
我们引入复数。
:::info[复数的加法]
对于两个复数
:::
:::info[复数的乘法]
对于两个复数
:::
再引入复数意义下的
复数意义下的
则它有以下性质:
-
\omega_n^i \neq \omega_n^j,\forall i\neq j -
\omega_n^k=\cos \frac{2k\pi}{n}+i\sin \frac{2k\pi}{n} -
\omega_n^0=\omega_n^n=1 -
\omega_{2n}^{2k}=\omega_n^k -
\omega_{n}^{k+\frac{n}{2}}=-\omega_n^k
正变换
这一步是将系数表示法转换为点表示法,即我们将原先的
我们先将
则有
容易发现,
当
那么
当
那么
容易发现,当我们求
由于使用分治,则时间复杂度优化到
逆变换
这一步是将点表示法转换为系数表示法。
为了方便书写,我们将原先的点对
对于由这
:::info[证明] 考虑倒推
考虑构造一个多项式
当
则
则
由于
当
则
因此
命题得证。 :::
有了以上结论,我们考虑如何快速求解。
令
则
最后的
二进制翻转
经过上面的推导,我们知道 FFT 需要使用分治,即上文所说的按照系数下标的奇偶性进行分类。
为了实现迭代,我们需要考虑如何正确地实现分类。
先看下面的例子:
最上面的数字表示最开始位置下标的二进制,最下面的数字表示最后位置下标的二进制。
容易发现,从最开始到分治结束,二进制进行了翻转。
那么对于第
其中,
代码实现如下:
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
蝶形运算
在 FFT 的分治合并过程中,我们利用以下两个公式同时计算出两个点的值:
其中
- 入:
A_1 和A_2 在相同位置k 上的值,以及旋转因子\omega_n^k 。 - 出:两个新值
A\left(\omega_n^k\right) 和A\left(\omega_n^{k+\frac{n}{2}}\right) 。
而在迭代实现中,我们可以按照长度从小到大逐层合并。对于当前层长度
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
配合上面的二进制翻转操作,我们可以成功地将 FFT 从递归转为迭代。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const double pi=acos(-1);
int n,m;
struct node{
double x,y;
node operator+ (const node& t)const{
return {x+t.x,y+t.y};
}
node operator- (const node& t)const{
return {x-t.x,y-t.y};
}
node operator* (const node& t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}a[maxn],b[maxn];
int rev[maxn],bit,tot;
void FFT(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>a[i].x;
for(int i=0;i<=m;i++) cin>>b[i].x;
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<tot;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=n+m;i++) cout<<(int)round(a[i].x/tot)<<" ";
return 0;
}
:::
快速数论变换 NTT
由于你谷并没有专门的 NTT 模版,这里用 FFT 的模板代替。(原因)
模板题
为了能够使得取模后不爆精度,考虑找到某个东西能满足单位根的性质来代替单位根。
那么,原根就登上了多项式的舞台。
给出两个正整数数
若有
若
令
则
且若
因此
所以
即我们可将单位根替换为原根。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n,m;
int a[maxn],b[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=(base*x)%P;
x=(x*x)%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
int rev[maxn],bit,tot;
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P,a[i+j+mid]=(x-y+P)%P;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<=n;i++) cin>>a[i],a[i]%=P;
for(int i=0;i<=m;i++) cin>>b[i],b[i]%=P;
while((1<<bit)<n+m+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
NTT(a,1),NTT(b,1);
for(int i=0;i<tot;i++) a[i]=(a[i]*b[i])%P;
NTT(a,-1);
int inv=qpow(tot,P-2);
for(int i=0;i<=n+m;i++) cout<<a[i]*inv%P<<" ";
return 0;
}
:::
值得注意的是,普通 NTT 的模数的必须形如
任意模数快速傅里叶变换 MTT
模板题
拆系数 FFT
由于取模,FFT 会爆精度。
而 NTT 又要求模数是友好模数,在这道题上不适用。
考虑改良原来的 FTT。
我们将每个系数拆分为低 15 位和高 15 位。
即
那么有
然后做 7 遍 FFT 即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define double long double
#define endl '\n'
using namespace std;
const int maxn=1e6+10;
const double pi=acos(-1);
int P;
int n,m;
int a[maxn],b[maxn],c[maxn];
struct node {
double x,y;
node operator+(const node&t)const {
return{x+t.x,y+t.y};
}
node operator-(const node&t)const {
return{x-t.x,y-t.y};
}
node operator*(const node&t)const {
return{x*t.x-y*t.y,x*t.y+y*t.x};
}
};
int rev[maxn],bit,tot;
node buf1[maxn],buf2[maxn],buf3[maxn],buf4[maxn];
void FFT(node a[],int inv) {
for(int i=0;i<tot;i++) {
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1) {
node w1={cos(pi/mid),inv*sin(pi/mid)};
for(int i=0;i<tot;i+=mid*2) {
node wk={1,0};
for(int j=0;j<mid;j++,wk=wk*w1) {
node x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(inv==-1) {
for(int i=0;i<tot;i++) {
a[i].x/=tot;
a[i].y/=tot;
}
}
}
int get_val(double x){
return ((long long)(round(x)))%P;
}
void MTT(int c[],int len){
bit=0;
while((1<<bit)<len) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<tot;i++) buf1[i]=buf2[i]=buf3[i]=buf4[i]={0,0};
for(int i=0;i<n;i++){
buf1[i]={(double)(a[i]&32767),0};
buf2[i]={(double)(a[i]>>15),0};
}
for(int i=0;i<m;i++){
buf3[i]={(double)(b[i]&32767),0};
buf4[i]={(double)(b[i]>>15),0};
}
FFT(buf1,1); FFT(buf2,1); FFT(buf3,1); FFT(buf4,1);
for(int i=0;i<tot;i++){
node tmp1=buf1[i]*buf3[i],tmp2=buf1[i]*buf4[i],tmp3=buf2[i]*buf3[i],tmp4=buf2[i]*buf4[i];
buf1[i]=tmp1;
buf2[i]=tmp2+tmp3;
buf3[i]=tmp4;
}
FFT(buf1,-1); FFT(buf2,-1); FFT(buf3,-1);
for(int i=0;i<len;i++){
int low=get_val(buf1[i].x),mid=get_val(buf2[i].x),high=get_val(buf3[i].x);
c[i]=(low+(mid<<15)+(high<<30))%P;
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m>>P;
n++; m++;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
for(int i=0;i<m;i++) cin>>b[i],b[i]%=P;
MTT(c,n+m-1);
for(int i=0;i<n+m-1;i++) cout<<c[i]<<" ";
return 0;
}
:::
三模 NTT
exgcd
三模 NTT 要使用 CRT,而 CRT 要用 exgcd(也不一定),所以先讲 exgcd。
exgcd 是用来求方程
的一组整数解的。
我们假设已经知道了方程
的一组整数解
又有
我们令
整理,得
对比我们想要求解的方程,我们惊喜地发现:
在代码中递归求解即可。
int exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return a;
}
int gcd=exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return gcd;
}
CRT
模板题
CRT 用来求解如下同余方程组:
考虑拆成如下几个方程组
则
然后令
那么想让
注意到这个东西用逆元求解即可。
但如何算出最小解呢?下面的话摘自 wsy_I 大神的学习笔记:
注意观察一个性质:假如这个
x 加上或者减去\prod^{n}_{i=1}a_i ,x 对于所有的a_i 取模之后的结果是相同的。如果加上或者减去的值不同于\prod^{n}_{i=1}a_i ,那么必定有某些取模之后的结果不同。于是构造最小解可以直接对于\prod^{n}_{i=1}a_i 取模。
温馨提示:模板题会爆 long long。
:::success[Code]
#include<bits/stdc++.h>
#define int __int128
#define endl '\n'
using namespace std;
const int maxn=10+5;
void write_(int x){
if(!x)return;
write_(x/10);
putchar(x%10+'0');
}
void write(int x){
if(!x)putchar('0');
else{
if(x<0) putchar('-'),x=-x;
write_(x);
}
}
inline int read(){
int x=0,y=0;char c;
while(!isdigit(c))y|=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
return y?-x:x;
}
void exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return;
}
exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return;
}
int n,a[maxn],b[maxn];
int ans,mod=1;
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
n=read();
for(int i=1;i<=n;i++) a[i]=read(),b[i]=read(),mod*=a[i];
for(int i=1;i<=n;i++){
int mul=1;
for(int j=1;j<=n;j++){
if(i==j) continue;
mul*=a[j];
}
int kx,ky;
exgcd(mul,a[i],kx,ky);
kx=(kx%a[i]+a[i])%a[i];
int y=kx*mul,x=y*b[i];
ans=(ans+x)%mod;
}
write(ans);
return 0;
}
:::
CRT+NTT
:::info[原理]
三模 NTT 是通过选取三个友好模数
这里选取的三个模数为:
优点是它们的原根都为
接下来考虑如何合并取模后的答案。我们使用 CRT。
对于同余方程组
其中
则有
其中
按照这样将
最后的答案输出真实系数值再对题目要求的
:::success[Code]
#include<bits/stdc++.h>
#define int __int128
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int mod1=998244353,mod2=(1<<26)*7+1,mod3=(1<<21)*479+1,G=3;
int a[maxn],b[maxn];
void write_(int x){
if(!x)return;
write_(x/10);
putchar(x%10+'0');
}
void write(int x){
if(!x)putchar('0');
else{
if(x<0) putchar('-'),x=-x;
write_(x);
}
}
inline int read(){
int x=0,y=0;char c=getchar();
while(!isdigit(c))y|=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
return y?-x:x;
}
int qpow(int x,int k,int mod){
int base=1;
while(k){
if(k&1) base=(base*x)%mod;
x=(x*x)%mod;
k>>=1;
}
return base%mod;
}
void NTT(int a[],int rev[],int type,int tot,int mod){
int Gi=qpow(G,mod-2,mod);
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(mod-1)/(mid<<1),mod);
else wi=qpow(Gi,(mod-1)/(mid<<1),mod);
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%mod){
int x=a[i+j],y=wk*a[i+j+mid]%mod;
a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
}
}
}
}
void exgcd(int a,int b,int &x,int &y){
if(!b){
x=1; y=0;
return;
}
exgcd(b,a%b,x,y);
int t=x;
x=y; y=t-a/b*y;
return;
}
int rev[maxn],bit,tot;
int n,m,p;
void MUL(int f[],int g[],int len1,int len2,int mod){
memset(rev,0,sizeof(rev)); bit=tot=0;
while((1<<bit)<len1+len2+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
NTT(f,rev,1,tot,mod),NTT(g,rev,1,tot,mod);
for(int i=0;i<tot;i++) f[i]=(f[i]*g[i])%mod;
NTT(f,rev,-1,tot,mod);
int inv=qpow(tot,mod-2,mod);
for(int i=0;i<=len1+len2;i++) f[i]=f[i]*inv%mod;
}
int CRT(int r1,int m1,int r2,int m2){
int k,t;
exgcd(m1,m2,k,t);
k=(k%m2+m2)%m2;
int d=((r2-r1)%m2+m2)%m2;
return r1+m1*((d*k)%m2);
}
int tmpa1[maxn],tmpa2[maxn],tmpa3[maxn];
int tmpb1[maxn],tmpb2[maxn],tmpb3[maxn];
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
n=read(),m=read(),p=read();
for(int i=0;i<=n;i++) a[i]=read(),tmpa1[i]=tmpa2[i]=tmpa3[i]=a[i];
for(int i=0;i<=m;i++) b[i]=read(),tmpb1[i]=tmpb2[i]=tmpb3[i]=b[i];
MUL(tmpa1,tmpb1,n,m,mod1);
MUL(tmpa2,tmpb2,n,m,mod2);
MUL(tmpa3,tmpb3,n,m,mod3);
int len=n+m+1;
for(int i=0;i<len;i++){
int r1=tmpa1[i],r2=tmpa2[i],r3=tmpa3[i];
int x12=CRT(r1,mod1,r2,mod2);
int mod12=mod1*mod2;
int x=CRT(x12,mod12,r3,mod3);
write(x%p),putchar(' ');
}
return 0;
}
:::
温馨提示:计算过程中会爆 long long。
快速沃尔什变换 FWT
模板题
核心思想
寻找一种可逆线性变换
然后通过逆变换得到
按位或
即求解
正变换
对于任意一个二进制数
则有
我们注意到展开后,
因为
至此,我们便得到了
逆变换
我们观察
并不是容易地发现这是个高维前缀和,还原用容斥即可,即
其中,
而在代码实现中,我们对每个二进制位(从低位到高位),如果该位是
按位与
即求解
正变换
和按位或类似,把
逆变换
直接给还原式子吧,因为和上面的类似。
而在代码实现中,我们对每个二进制位(从低位到高位),如果该位是
按位异或
即求解
直接给结论吧,想看详细推导的可以看此文实在看不懂。
正变换
代码实现步骤如下:
- 把数组分成两个一组(相邻两个数一组)。
- 对于每一组
(x,y) ,计算两个新数:- 新左边的数:
x+y 。 - 新右边的数:
x-y 。
- 新左边的数:
- 用这两个新数替换原来的数。
逆变换
代码实现步骤如下:
- 把数组分成两个一组(相邻两个数一组)。
- 对于每一组
(x,y) ,计算两个新数:- 新左边的数:
\frac{x+y}{2} 。 - 新右边的数:
\frac{x-y}{2} 。
- 新左边的数:
- 用这两个新数替换原来的数。
其实就比正变换多除了个
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=1<<17;
const int mod=998244353,inv2=(mod+1)/2;
int n;
int tot;
int a[maxn],b[maxn];
int tmp_a[maxn],tmp_b[maxn];
void FWT_or(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
if(inv==1) tmp[i+j+mid]=(x+y)%mod;
else tmp[i+j+mid]=(y-x+mod)%mod;
}
}
}
}
void FWT_and(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
if(inv==1) tmp[i+j]=(x+y)%mod;
else tmp[i+j]=(x-y+mod)%mod;
}
}
}
}
void FWT_xor(int tmp[],int inv,int len){
for(int mid=1;mid<len;mid<<=1){
for(int i=0;i<len;i+=mid*2){
for(int j=0;j<mid;j++){
int x=tmp[i+j],y=tmp[i+j+mid];
tmp[i+j]=(x+y)%mod; tmp[i+j+mid]=(x-y+mod)%mod;
if(inv==-1){
tmp[i+j]=tmp[i+j]*inv2%mod;
tmp[i+j+mid]=tmp[i+j+mid]*inv2%mod;
}
}
}
}
}
void init(){
memcpy(tmp_a,a,sizeof(tmp_a));
memcpy(tmp_b,b,sizeof(tmp_b));
}
void OR(int tmp1[],int tmp2[],int len){
FWT_or(tmp1,1,len); FWT_or(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_or(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
void AND(int tmp1[],int tmp2[],int len){
FWT_and(tmp1,1,len); FWT_and(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_and(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
void XOR(int tmp1[],int tmp2[],int len){
FWT_xor(tmp1,1,len); FWT_xor(tmp2,1,len);
for(int i=0;i<len;i++) tmp1[i]=tmp1[i]*tmp2[i]%mod;
FWT_xor(tmp1,-1,len);
for(int i=0;i<len;i++) cout<<tmp1[i]<<" ";
cout<<endl;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n; tot=1<<n;
for(int i=0;i<tot;i++) cin>>a[i];
for(int i=0;i<tot;i++) cin>>b[i];
init(); OR(tmp_a,tmp_b,tot);
init(); AND(tmp_a,tmp_b,tot);
init(); XOR(tmp_a,tmp_b,tot);
return 0;
}
:::
多项式乘法逆
模板题
这里用的是牛顿迭代法,牛顿迭代一般用于求一个函数的根(即与
牛顿迭代公式:
其中,
可以理解为利用函数在某点的切线来逼近函数本身,并通过不断求解该切线与
回到本题,我们考虑移项。
然后令
然后套牛顿迭代公式,得
又由于
化简,得
然后递推做 NTT 即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e6+10;
const int P=998244353,G=3;
int n;
int a[maxn],b[maxn],c[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i]; for(int i=n;i<tot;i++) c[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
INV(b,a,n);
for(int i=0;i<n;i++) cout<<b[i]<<' ';
return 0;
}
:::
:::info[时间复杂度分析]
设
展开递归:
令
由于
已知
:::
任意模数多项式乘法逆
模板题
把 NTT 换成 MTT 即可。
这里写的是拆系数 FFT。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define double long double
#define endl '\n'
using namespace std;
const int maxn=3e5+10;
const int P=1e9+7;
int n;
int a[maxn],b[maxn],c[maxn];
int rev[maxn],bit,tot;
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const double pi=acos(-1);
struct node {
double x,y;
node operator+(const node&t)const {
return{x+t.x,y+t.y};
}
node operator-(const node&t)const {
return{x-t.x,y-t.y};
}
node operator*(const node&t)const {
return{x*t.x-y*t.y,x*t.y+y*t.x};
}
};
node buf1[maxn],buf2[maxn],buf3[maxn],buf4[maxn];
void FFT(node a[],int inv) {
for(int i=0;i<tot;i++) {
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1) {
node w1={cos(pi/mid),inv*sin(pi/mid)};
for(int i=0;i<tot;i+=mid*2) {
node wk={1,0};
for(int j=0;j<mid;j++,wk=wk*w1) {
node x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(inv==-1) {
for(int i=0;i<tot;i++) {
a[i].x/=tot;
a[i].y/=tot;
}
}
}
int get_val(double x){
return (long long)(round(x))%P;
}
void MTT(int x[],int y[],int res[],int len){
bit=0;
while((1<<bit)<len) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<tot;i++) buf1[i]=buf2[i]=buf3[i]=buf4[i]={0,0};
for(int i=0;i<len;i++){
buf1[i]={(double)(x[i]&32767),0};
buf2[i]={(double)(x[i]>>15),0};
}
for(int i=0;i<len;i++){
buf3[i]={(double)(y[i]&32767),0};
buf4[i]={(double)(y[i]>>15),0};
}
FFT(buf1,1); FFT(buf2,1); FFT(buf3,1); FFT(buf4,1);
for(int i=0;i<tot;i++){
node tmp1=buf1[i]*buf3[i],tmp2=buf1[i]*buf4[i],tmp3=buf2[i]*buf3[i],tmp4=buf2[i]*buf4[i];
buf1[i]=tmp1;
buf2[i]=tmp2+tmp3;
buf3[i]=tmp4;
}
FFT(buf1,-1); FFT(buf2,-1); FFT(buf3,-1);
for(int i=0;i<len;i++){
int low=get_val(buf1[i].x),mid=get_val(buf2[i].x),high=get_val(buf3[i].x);
res[i]=(low+(mid<<15)+(high<<30))%P;
if(res[i]<0) res[i]+=P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
while(len<(n<<1)) len<<=1;
static int tmp[maxn];
for(int i=0;i<n;i++) tmp[i]=a[i];
for(int i=n;i<len;i++) tmp[i]=0;
for(int i=n;i<len;i++) b[i]=0;
MTT(b,b,c,len); MTT(c,tmp,c,len);
for(int i=0;i<n;i++) b[i]=(2*b[i]%P-c[i]+P)%P;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=0;i<n;i++) cin>>a[i],a[i]%=P;
INV(b,a,n);
for(int i=0;i<n;i++) cout<<b[i]<<' ';
return 0;
}
:::
多项式除法
模板题
要用一个很巧的变形。
然后令
这个式子用多项式求逆即可,问题来到如何根据
其实容易得到,
则有
又令
则有
至此,式子转换为底数之和的形式。 :::
例题一:[ZJOI2014] 力
容易把题目所求转换为
不妨令
则有
不妨又令
则有
前面部分已经是卷积形式了,先丢一边推后面部分。
注意到
再根据我们的求解套路,令
则有
总合一下:
至此,我们将所有式子转换为卷积形式,FFT 求解即可。
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=4e5+10;
const double pi=acos(-1);
int n;
struct node{
double x,y;
node operator+ (const node& t)const{
return {x+t.x,y+t.y};
}
node operator- (const node &t)const{
return {x-t.x,y-t.y};
}
node operator* (const node &t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}f1[maxn],f2[maxn],g[maxn];
node tmp1[maxn],tmp2[maxn];
int rev[maxn],bit,tot;
void FFT(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=wk*a[i+j+mid];
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>f1[i].x,g[i].x=1.0/(i*i);
for(int i=1;i<=n;i++) f2[i]=f1[n-i+1];
while((1<<bit)<2*n+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
FFT(f1,1); FFT(g,1); FFT(f2,1);
for(int i=0;i<tot;i++) tmp1[i]=f1[i]*g[i],tmp2[i]=g[i]*f2[i];
FFT(tmp1,-1); FFT(tmp2,-1);
for(int i=1;i<=n;i++) printf("%.3lf\n",tmp1[i].x/tot-tmp2[n-i+1].x/tot);
return 0;
}
:::
例题二:[AHOI2017 / HNOI2017] 礼物
:::info[题意]{open}
给定两个长度为
考虑把平方展开,则有
注意到前面两项是定值,后面一项可用二次函数求最值解决。
问题来到如何解决
可以看做求解
先运用求解套路。
即有
其中
再考虑断环为链,即
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=3e5+10;
const double pi=acos(-1);
int n,m;
int ans,sumx,sumy;
struct node{
double x,y;
node operator+ (const node&t)const{
return {x+t.x,y+t.y};
}
node operator- (const node&t)const{
return {x-t.x,y-t.y};
}
node operator* (const node&t)const{
return {x*t.x-y*t.y,x*t.y+y*t.x};
}
}a1[maxn],a2[maxn],b[maxn];
int rev[maxn],bit,tot;
void fft(node a[],int inv){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
auto w1=node({cos(pi/mid),inv*sin(pi/mid)});
for(int i=0;i<tot;i+=mid*2){
auto wk=node({1,0});
for(int j=0;j<mid;j++,wk=wk*w1){
auto x=a[i+j],y=a[i+j+mid]*wk;
a[i+j]=x+y; a[i+j+mid]=x-y;
}
}
}
if(inv==-1){
for(int i=0;i<tot;i++) a[i].x/=tot;
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=0;i<n;i++){
cin>>a1[i].x;
ans+=a1[i].x*a1[i].x;
sumx+=a1[i].x;
}
for(int i=0;i<n;i++){
cin>>b[i].x;
b[n+i].x=b[i].x;
ans+=b[i].x*b[i].x;
sumy+=b[i].x;
}
for(int i=0;i<n;i++) a2[i].x=a1[n-i-1].x;
int len=3*n;
while((1<<bit) < len+1) bit++;
tot=1<<bit;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
fft(a2,1); fft(b,1);
for(int i=0;i<tot;i++) b[i]=b[i]*a2[i];
fft(b,-1);
double maxx=LLONG_MIN;
for(int k=0;k<n;k++) maxx=max(maxx,b[n-1+k].x);
int c=llround((sumy-sumx)*1.0/n);
cout<<ans-2*llround(maxx)+n*c*c+2*c*(sumx-sumy);
return 0;
}
:::
2. 优化 DP
此类问题一般通过 DP 转移方程构建生成函数,通过多项式解决。
没有生成函数基础的先看这里。
例题:付公主的背包
对于一个体积为
可以看做消耗
那么根据乘法原理可得,答案即为这
我们考虑把生成函数写成封闭形式,则有
那答案则为
然后有个很经典的技巧:将多项式取对数,可以将乘法转为加法。令
然后利用
那么对于一个体积
:::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=4e5+10;
const int P=998244353,G=3;
int n,m;
int fa[maxn],fb[maxn];
int c[maxn],d[maxn],e[maxn],tmp[maxn];
int rev[maxn],bit,tot;
int ln_G[maxn],tmp_G[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
const int Gi=qpow(G,P-2);
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
int df[maxn],inv_f[maxn],mul_res[maxn];
void LN(int f[],int g[],int n){
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
LN(g,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
int inv[maxn];
void init(int n){
inv[0]=inv[1]=1;
for(int i=2;i<n;i++) inv[i]=inv[P%i]*(P-P/i)%P;
}
int t[maxn];
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin>>n>>m;
init(m+5);
for(int i=1;i<=n;i++){
int x;
cin>>x; t[x]++;
}
for(int i=1;i<=m;i++){
if(t[i]){
for(int j=i;j<=m;j+=i) fa[j]=(fa[j]+t[i]*inv[j/i])%P;
}
}
EXP(fb,fa,m+1);
for(int i=1;i<=m;i++) cout<<fb[i]%P<<endl;
return 0;
}
:::
总结
本文讲得都是一些基础且常用的多项式操作以及一些经典例题,多项式多点求值和快速差值后面有时间会考虑补上。作者的代码常数较好,请放心食用。想看进阶的可以看 Great_Influence 的多项式总结。
给个整合 :::success[Code]
#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;
const int maxn=6000010;
const int P=998244353,G=3;
const int Gi=332748118;
int rev[maxn],bit,tot;
int c[maxn],d[maxn],e[maxn],tmp[maxn],tmp2[maxn];
int ln_G[maxn],tmp_G[maxn];
int tmp_b1[maxn],tmp_b2[maxn];
int fr[maxn],gr[maxn],inv_gr[maxn],qr[maxn],tmpf[maxn];
int qpow(int x,int k){
int base=1;
while(k){
if(k&1) base=base*x%P;
x=x*x%P;
k>>=1;
}
return base%P;
}
void NTT(int a[],int type){
for(int i=0;i<tot;i++){
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int mid=1;mid<tot;mid<<=1){
int wi;
if(type==1) wi=qpow(G,(P-1)/(mid<<1));
else wi=qpow(Gi,(P-1)/(mid<<1));
for(int i=0;i<tot;i+=mid*2){
int wk=1;
for(int j=0;j<mid;j++,wk=wk*wi%P){
int x=a[i+j],y=wk*a[i+j+mid]%P;
a[i+j]=(x+y)%P;
a[i+j+mid]=(x-y+P)%P;
}
}
}
if(type==-1){
int inv_n=qpow(tot,P-2);
for(int i=0;i<tot;i++) a[i]=a[i]*inv_n%P;
}
}
void INV(int b[],int a[],int n){
if(n==1){
b[0]=qpow(a[0],P-2);
return;
}
INV(b,a,(n+1)>>1);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) c[i]=a[i];
for(int i=n;i<tot;i++) c[i]=0;
for(int i=(n+1)/2;i<tot;i++) b[i]=0;
NTT(b,1); NTT(c,1);
for(int i=0;i<tot;i++) b[i]=(2-b[i]*c[i]%P+P)%P*b[i]%P;
NTT(b,-1);
for(int i=n;i<tot;i++) b[i]=0;
}
void MUL(int a[],int b[],int c[],int n,int m){
int len=1;
bit=0;
while(len<(n+m-1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) d[i]=a[i];
for(int i=n;i<tot;i++) d[i]=0;
for(int i=0;i<m;i++) e[i]=b[i];
for(int i=m;i<tot;i++) e[i]=0;
NTT(d,1); NTT(e,1);
for(int i=0;i<tot;i++) c[i]=d[i]*e[i]%P;
NTT(c,-1);
}
void DIFF(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i-1]=i*f[i]%P;
g[n-1]=0;
}
void INTG(int f[],int g[],int n){
for(int i=1;i<n;i++) g[i]=f[i-1]*qpow(i,P-2)%P;
g[0]=0;
}
void LN(int f[],int g[],int n){
static int df[maxn],inv_f[maxn],mul_res[maxn];
for(int i=0;i<n;i++) tmp[i]=f[i];
DIFF(f,df,n);
INV(inv_f,tmp,n);
MUL(df,inv_f,mul_res,n,n);
INTG(mul_res,g,n);
}
void EXP(int g[],int f[],int n){
if(n==1){
g[0]=1;
return;
}
EXP(g,f,(n+1)>>1);
for(int i=0;i<n;i++) tmp_G[i]=g[i];
LN(tmp_G,ln_G,n);
for(int i=0;i<n;i++) tmp_G[i]=(f[i]-ln_G[i]+P)%P;
tmp_G[0]=(tmp_G[0]+1)%P;
MUL(g,tmp_G,g,n,n);
for(int i=n;i<tot;i++) g[i]=0;
}
int w;
struct node{
int x,y;
node operator*(const node& t)const{
return{(x*t.x%P+y*t.y%P*w%P)%P,(x*t.y%P+y*t.x%P)%P};
}
};
bool check(int x){
if(x%P==0) return 1;
return qpow(x,(P-1)/2)==1;
}
int Cipolla(int n){
int a=rand()%P;
w=(a*a%P-n+P)%P;
while(check((a*a-n+P)%P)){
a=rand()%P;
w=(a*a%P-n+P)%P;
}
node res={a,1};
int k=(P+1)/2;
node base={1,0};
while(k){
if(k&1) base=base*res;
res=res*res;
k>>=1;
}
int x1=base.x,x2=(-x1+P)%P;
return min(x1,x2);
}
void SQRT(int g[],int f[],int n){
if(n==1){
g[0]=Cipolla(f[0]);
return;
}
SQRT(g,f,(n+1)>>1);
for(int i=0;i<2*n;i++) d[i]=tmp_b1[i]=tmp_b2[i]=0;
for(int i=0;i<n;i++) tmp_b1[i]=g[i];
INV(d,tmp_b1,n);
int len=1;
bit=0;
while(len<(n<<1)) len<<=1,bit++;
tot=len;
for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) tmp_b2[i]=f[i];
NTT(d,1); NTT(tmp_b2,1);
for(int i=0;i<tot;i++) d[i]=d[i]*tmp_b2[i]%P;
NTT(d,-1);
int inv2=qpow(2,P-2);
for(int i=0;i<tot;i++) g[i]=(g[i]+d[i])%P*inv2%P;
for(int i=n;i<tot;i++) g[i]=0;
}
void move(int f[],int g[],int t,int n){
for(int i=0;i<n;i++){
if(0<=i-t&&i-t<=n-1) g[i]=f[i-t];
else g[i]=0;
}
}
void POW(int fb[],int fa[],int n,string K){
int k1=0,k2=0,k=0;
bool large=false;
for(char ch:K){
k1=(k1*10+(ch-'0'))%P;
k2=(k2*10+(ch-'0'))%(P-1);
if(!large){
k=k*10+(ch-'0');
if(k>1e9) large=true;
}
}
int t=-1;
for(int i=0;i<n;i++){
if(fa[i]&&t==-1) t=i;
}
if(t==-1){
for(int i=0;i<n;i++) fb[i]=0;
return;
}
if((large&&t)||(k>=0&&t*k>=n)){
for(int i=0;i<n;i++) fb[i]=0;
return;
}
static int tmp_a[maxn];
move(fa,tmp_a,-t,n);
int inv_leading=qpow(fa[t],P-2);
for(int i=0;i<n-t;i++) tmp_a[i]=tmp_a[i]*inv_leading%P;
LN(tmp_a,tmp2,n-t);
for(int i=0;i<n-t;i++) tmp_a[i]=tmp2[i]*k1%P;
EXP(fb,tmp_a,n-t);
int coeff=qpow(fa[t],k2);
for(int i=0;i<n-t;i++) fb[i]=fb[i]*coeff%P;
fill(tmp_a,tmp_a+t*k,0);
move(fb,tmp_a,t*k,n);
for(int i=0;i<n;i++) fb[i]=tmp_a[i];
}
void DIV(int q[],int r[],int f[],int g[],int n,int m){
for(int i=0;i<=n;i++) fr[i]=f[n-i];
for(int i=0;i<=m;i++) gr[i]=g[m-i];
INV(inv_gr,gr,n-m+1);
MUL(fr,inv_gr,qr,n+1,n-m+1);
for(int i=0;i<=n-m;i++) q[i]=qr[n-m-i];
MUL(q,g,tmpf,n-m+1,m+1);
for(int i=0;i<m;i++) r[i]=(f[i]-tmpf[i]+P)%P;
for(int i=m;i<n;i++) r[i]=0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
return 0;
}
:::
参考文献
attack-【模板】多项式乘法(NTT)
Prean - 模板 P4245 题解
wsy_I 的 CRT 学习笔记(不方便公开)
KAMIYA_KINA - P4238 【模板】多项式乘法逆 题解
Kewth - 题解 P5491 【模板】二次剩余