半在线卷积
本来想交到分治 FFT 的题解里去的,结果题解通道关了,那就不管了(
大概是计算形如
中的
与之相对的,全在线卷积就是必须得到
可以用这个东西比较方便地架构多项式板子
对 cdq 分治 NTT 的优化
现在先假设
我们知道正常的 cdq 分治 NTT 是分治左边,计算左边对右边的贡献,然后分治右边. 这样具有复杂度
注意这里
取
然而说起来容易写起来难,为了跑得够快需要精细的实现:
- B 取为
8 或16 比较快 - 小范围暴力
- 分出来的
B 个子问题的大小取为2 的幂,最后一个不整的块无所谓,这样不浪费 DFT 长度 - 公式的细节,假设我们当前分出来的子问题大小为
d ,那么第i 个块(从它0 标号)的范围为[id,(i+1)d) (对于这个区间而言),所以计算第i 个块对第j 个块的贡献时有用的g 的区间为((j-i-1)d,(j-i+1)d) ,并且每层内这个部分都是一样的所以可以在分治外部预处理g 在每一层的点值. 此外我们是做次数为2d-1 和d-1 的多项式乘法并且只取[d,2d-1] 项系数,所以可以直接做长度为 2d 的 DFT,这个循环卷积刚好溢出不到我们需要的位置.
代码相当漂亮
#define B 3
int *_f[20][1<<B],*_g[20][1<<B];
void cdq(int l,int r,int dep,int f[],const int g[],void brute(int*,const int*,int,int))
{
if(r-l+1<=32){brute(f,g,l,r);return;}//暴力是传入的,根据需要写
static int tmpf[N];
int d=1<<((dep-1)*B);//这一层的块大小
for(int i=0;;i++)
{
int L=l+i*d,R=min(r,L+d-1);//子问题区间
if(i)
{
fill(tmpf,tmpf+(d<<1),0);
for(int j=0;j<i;j++)//算贡献
for(int k=0;k<(d<<1);k++)
tmpf[k]=(tmpf[k]+1ll*_f[dep][j][k]*_g[dep][i-j][k])%mod;
IDFT(tmpf,d<<1);
for(int i=L;i<=R;i++)f[i]=qmod1(f[i]+tmpf[i-L+d]);
}
cdq(L,R,dep-1,f,g,brute);
if(R==r)return;
fill(_f[dep][i],_f[dep][i]+(d<<1),0);//_f记录点值
copy(f+L,f+R+1,_f[dep][i]);
DFT(_f[dep][i],d<<1);
}
}
void cdq_pre(int f[],const int g[],int n,void brute(int*,const int*,int,int))
{
fill(f,f+n,0);
if(n<=128){brute(f,g,0,n-1);return;}
int len=1,dep=0;while(len<n)len<<=B,++dep;len>>=B;
int *p=pool;
//clock_t begt=clock();
for(int i=1;i<=dep;i++)//枚举所有的层
{
int d=1<<((i-1)*B),tn=min((1<<B)-1,(n-1)/d);//子问题个数和子问题长度
for(int j=1;j<=tn;j++)
{
int l=(j-1)*d+1,r=min(n-1,(j+1)*d-1);//_g记录g的转移点值
_f[i][j-1]=p,p+=d<<1;_g[i][j]=p;p+=d<<1;
fill(_g[i][j],_g[i][j]+(d<<1),0);
copy(g+l,g+r+1,_g[i][j]+1);
DFT(_g[i][j],d<<1);
}
}
cdq(0,n-1,dep,f,g,brute);
}
#undef B
那剩下的东西就很简单了,我们只需要决定 brute 怎么写.
-
Exp:
F(z)=e^{G(z)}\Longrightarrow zF'(z)=zG'(z)F(z) nf_n=\sum_{i=1}^nig_if_{n-i} -
Ln:
F(z)=\ln G(z)\Longrightarrow zF'(z)G(z)=G'(z)
(其实有了这俩已经能把剩下的都表示出来了)
- Inv:
F(z)G(z)=1
-
Sqrt&Pow:
F(z)=G(z)^k\Longrightarrow zF'(z)G(z)=kzG'(z)F(z) ,不过这个东西有点麻烦,不如直接 Ln+Exp//loj 挑战多项式 #include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<ctime> #include<cstdlib> using namespace std; typedef unsigned long long ull; const int mod=998244353; const int N=5e5; int qmod1(int x){return x>=mod?x-mod:x;} int qmod2(int x){return x+(x>>31&mod);} int qpower(int a,int b){int ans=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;return ans;} pair<int,int>mul_comp(const pair<int,int>&a,const pair<int,int>&b,const int &d) { return make_pair((1ll*a.first*b.first+1ll*a.second*b.second%mod*d)%mod,(1ll*a.first*b.second+1ll*a.second*b.first)%mod); } int mod_sqrt(int x) { int a=rand()%mod,t; while(qpower(t=qmod2(1ll*a*a%mod-x),(mod-1)>>1)==1)a=rand()%mod; pair<int,int>w=make_pair(a,mod-1),ans=make_pair(1,0); int b=(mod+1)>>1; for(;b;b>>=1,w=mul_comp(w,w,t))if(b&1)ans=mul_comp(ans,w,t); return min(ans.first,mod-ans.first); } namespace Poly { int w[N],inv[N],pool[N<<4]; int limn,rev[N],lg[N]; void prework(int n) { limn=1;while(limn<n)limn<<=1; for(int i=1;i<limn;i++)rev[i]=rev[i>>1]>>1|((i&1)?(limn>>1):0); for(int i=1;i<limn;i<<=1) { w[i]=1;int omg=qpower(3,(mod-1)/(i<<1)); for(int j=1;j<i;j++)w[i+j]=1ll*w[i+j-1]*omg%mod; } for(int i=2;i<=limn;i++)lg[i]=lg[i>>1]+1; inv[1]=1;for(int i=2;i<limn;i++)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod; } void DFT(int a[],int n) { int t=lg[limn/n]; static ull tmp[N]; for(int i=0;i<n;i++)tmp[rev[i]>>t]=a[i]; for(int i=1;i<n;i<<=1) for(int j=0;j<n;j+=i<<1) for(int k=0;k<i;k++) { unsigned t=tmp[i+j+k]*w[i+k]%mod; tmp[i+j+k]=tmp[j+k]+mod-t,tmp[j+k]+=t; } for(int i=0;i<n;i++)a[i]=tmp[i]%mod; } void IDFT(int a[],int n) { reverse(a+1,a+n);DFT(a,n); int iv=mod-(mod-1)/n; for(int i=0;i<n;i++)a[i]=1ll*a[i]*iv%mod; } void print(int a[],int n) { for(int i=0;i<n;i++)cout<<a[i]<<" ";puts(""); } #define B 3 int *_f[20][1<<B],*_g[20][1<<B]; void cdq(int l,int r,int dep,int f[],const int g[],void brute(int*,const int*,int,int)) { if(r-l+1<=32){brute(f,g,l,r);return;} static int tmpf[N]; int d=1<<((dep-1)*B); for(int i=0;;i++) { int L=l+i*d,R=min(r,L+d-1); if(i) { fill(tmpf,tmpf+(d<<1),0); for(int j=0;j<i;j++) for(int k=0;k<(d<<1);k++) tmpf[k]=(tmpf[k]+1ll*_f[dep][j][k]*_g[dep][i-j][k])%mod; IDFT(tmpf,d<<1); for(int i=L;i<=R;i++)f[i]=qmod1(f[i]+tmpf[i-L+d]); } cdq(L,R,dep-1,f,g,brute); if(R==r)return; fill(_f[dep][i],_f[dep][i]+(d<<1),0); copy(f+L,f+R+1,_f[dep][i]);//fill(_f[dep][i]+R-L+1,_f[dep][i]+(d<<1),0); DFT(_f[dep][i],d<<1); } } void cdq_pre(int f[],const int g[],int n,void brute(int*,const int*,int,int)) { fill(f,f+n,0); if(n<=128){brute(f,g,0,n-1);return;} int len=1,dep=0;while(len<n)len<<=B,++dep;len>>=B; int *p=pool; //clock_t begt=clock(); for(int i=1;i<=dep;i++) { int d=1<<((i-1)*B),tn=min((1<<B)-1,(n-1)/d); for(int j=1;j<=tn;j++) { int l=(j-1)*d+1,r=min(n-1,(j+1)*d-1); _f[i][j-1]=p,p+=d<<1;_g[i][j]=p;p+=d<<1; fill(_g[i][j],_g[i][j]+(d<<1),0); copy(g+l,g+r+1,_g[i][j]+1); DFT(_g[i][j],d<<1); } } cdq(0,n-1,dep,f,g,brute); } #undef B void Der(int f[],const int g[],int n) { for(int i=0;i<n-1;i++)f[i]=1ll*(i+1)*g[i+1]%mod; f[n-1]=0; } void Int(int f[],const int g[],int n) { for(int i=1;i<n;i++)f[i]=1ll*inv[i]*g[i-1]%mod; f[0]=0; } void brute_ln(int f[],const int g[],int l,int r) { if(!l)f[l]=0;else f[l]=qmod2(1ll*l*g[l]%mod-f[l]); for(int i=l+1;i<=r;i++) { for(int j=l;j<i;j++) f[i]=(f[i]+1ll*f[j]*g[i-j])%mod; f[i]=qmod2(1ll*i*g[i]%mod-f[i]); } } void Ln(int f[],const int g[],int n) { cdq_pre(f,g,n,brute_ln); f[0]=0;for(int i=1;i<n;i++)f[i]=1ll*f[i]*inv[i]%mod; } void brute_exp(int f[],const int g[],int l,int r) { if(!l)f[l]=1;else f[l]=1ll*f[l]*inv[l]%mod; for(int i=l+1;i<=r;i++) { for(int j=l;j<i;j++)f[i]=(f[i]+1ll*f[j]*g[i-j])%mod; f[i]=1ll*f[i]*inv[i]%mod; } } void Exp(int f[],const int g[],int n) { static int tmpg[N]; for(int i=0;i<n;i++)tmpg[i]=1ll*i*g[i]%mod; cdq_pre(f,tmpg,n,brute_exp); } void brute_inv(int f[],const int g[],int l,int r) { if(!l)f[l]=1; for(int i=l+1;i<=r;i++) for(int j=l;j<i;j++)f[i]=(f[i]+1ll*f[j]*g[i-j])%mod; } void Inv(int f[],const int g[],int n) { static int tmpg[N]; int iv=qpower(g[0],mod-2); for(int i=1;i<n;i++)tmpg[i]=1ll*(mod-g[i])*iv%mod; cdq_pre(f,tmpg,n,brute_inv); for(int i=0;i<n;i++)f[i]=1ll*f[i]*iv%mod; } void Pow(int f[],const int g[],int n,int K) { Ln(f,g,n); for(int i=0;i<n;i++)f[i]=1ll*f[i]*K%mod; static int tmpf[N]; copy(f,f+n,tmpf); Exp(f,tmpf,n); } void Sqrt(int f[],const int g[],int n) { int t=g[0],iv=qpower(t,mod-2); static int tmpg[N]; for(int i=0;i<n;i++)tmpg[i]=1ll*g[i]*iv%mod; Pow(f,tmpg,n,(mod+1)>>1); t=mod_sqrt(t); for(int i=0;i<n;i++)f[i]=1ll*f[i]*t%mod; } } int n,f[N],g[N]; int getin() { int x=0;char ch=getchar(); while(ch<'0'||ch>'9')ch=getchar(); while(ch>='0'&&ch<='9')x=x*10+ch-48,ch=getchar(); return x; } int main() { // freopen("data.txt","r",stdin); n=getin()+1;int K=getin();Poly::prework(n); for(int i=0;i<n;i++)g[i]=getin(); static int tmpg[N]; copy(g,g+n,tmpg); Poly::Sqrt(f,g,n);swap(f,g); Poly::Inv(f,g,n);swap(f,g); Poly::Int(f,g,n);swap(f,g); Poly::Exp(f,g,n);swap(f,g); for(int i=0;i<n;i++)g[i]=qmod2(tmpg[i]-g[i]); g[0]=1;Poly::Ln(f,g,n);swap(f,g); g[0]=1;Poly::Pow(f,g,n,K);swap(f,g); Poly::Der(f,g,n);--n; for(int i=0;i<n;i++)printf("%d ",f[i]); }(在 luogu 抢了几个最优解,不过 loj 上有更快的)
不过显然这个代码并不是最终的板子(我暂时也不打算写),它甚至连乘法都没封装进去( 最终的版本大概还是要写成 vector 形式,比较易用.
半在线卷积
上面的问题里我们的
例题
我们先快进到无标号有根树的生成函数方程
那么
令
显然
边界为
先考虑正常的二叉分治(左区间长度为区间长度的 highbit),那么算贡献的时候需要由
-
当
l=0 时,只计算f_{l,mid}\ast g_{l,mid} 对f_{mid+1,r} 的贡献 -
当
l\neq 0 时,发现前面漏掉的贡献是f_{l,mid}\ast g_{r-l} 和g_{l,mid}\ast f_{r-l} ,加上即可.
而这个东西稍微改改就能变成我们的
只有
-
当
l=0 时,假设当前已经算完了前i-1 个子问题,那么第0 个子问题对第i 个子问题的贡献只计算f_{0,d-1}\ast g_{(i-1)d+1,id-1} 对f_{id,(i+1)d-1} 的贡献 -
当
l\neq 0 时,第i 个子问题除了正常的用前面的子问题的f 转移之外还要加上从前面的子问题的g 转移过来的结果.
可能代码会更清楚一些(
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<ctime>
#include<cstdlib>
using namespace std;
typedef unsigned long long ull;
const int mod=998244353;
const int N=5e5;
int qmod1(int x){return x>=mod?x-mod:x;}
int qmod2(int x){return x+(x>>31&mod);}
int n,f[N],g[N];
namespace Poly
{
int w[N],inv[N],pool[N<<4];
int limn,rev[N],lg[N];
int qpower(int a,int b){int ans=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;return ans;}
void prework(int n)
{
limn=1;while(limn<n)limn<<=1;
for(int i=1;i<limn;i++)rev[i]=rev[i>>1]>>1|((i&1)?(limn>>1):0);
for(int i=1;i<limn;i<<=1)
{
w[i]=1;int omg=qpower(3,(mod-1)/(i<<1));
for(int j=1;j<i;j++)w[i+j]=1ll*w[i+j-1]*omg%mod;
}
for(int i=2;i<=limn;i++)lg[i]=lg[i>>1]+1;
inv[1]=1;for(int i=2;i<limn;i++)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
}
void DFT(int a[],int n)
{
int t=lg[limn/n];
static ull tmp[N];
for(int i=0;i<n;i++)tmp[rev[i]>>t]=a[i];
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=i<<1)
for(int k=0;k<i;k++)
{
unsigned t=tmp[i+j+k]*w[i+k]%mod;
tmp[i+j+k]=tmp[j+k]+mod-t,tmp[j+k]+=t;
}
for(int i=0;i<n;i++)a[i]=tmp[i]%mod;
}
void IDFT(int a[],int n)
{
reverse(a+1,a+n);DFT(a,n);
int iv=mod-(mod-1)/n;
for(int i=0;i<n;i++)a[i]=1ll*a[i]*iv%mod;
}
void print(int a[],int n)
{
for(int i=0;i<n;i++)cout<<a[i]<<" ";puts("");
}
#define B 3
int *_f0[20][1<<B],*_g0[20][1<<B],*_f[20][1<<B],*_g[20][1<<B];
//_f0,_g0 和上个问题中的 _g 的作用类似,_f,_g 和上个问题中的 _f 作用类似
void calc0(const int f[],const int g[],int i,int dep)
{
int d=1<<((dep-1)*B);
int l=(i-1)*d+1,r=(i+1)*d-1;
fill(_g0[dep][i],_g0[dep][i]+(d<<1),0),fill(_f0[dep][i],_f0[dep][i]+(d<<1),0);
copy(g+l,g+r+1,_g0[dep][i]+1),copy(f+l,f+r+1,_f0[dep][i]+1);
DFT(_g0[dep][i],d<<1),DFT(_f0[dep][i],d<<1);
}
void cdq(int l,int r,int dep,int f[],int g[],void brute(int*,int*,int,int))
{
if(r-l+1<=1){brute(f,g,l,r);return;}
static int tmpf[N];
int d=1<<((dep-1)*B);
for(int i=0;;i++)
{
int L=l+i*d,R=min(r,L+d-1);
if(i)
{
fill(tmpf,tmpf+(d<<1),0);
for(int j=l?0:1;j<i;j++)
for(int k=0;k<(d<<1);k++)
tmpf[k]=(tmpf[k]+1ll*_f[dep][j][k]*_g0[dep][i-j][k])%mod;
if(l)//l!=0还需要加上的贡献
{
for(int j=0;j<i;j++)
for(int k=0;k<(d<<1);k++)
tmpf[k]=(tmpf[k]+1ll*_g[dep][j][k]*_f0[dep][i-j][k])%mod;
}
IDFT(tmpf,d<<1);
for(int i=L;i<=R;i++)f[i]=qmod1(f[i]+tmpf[i-L+d]);
}
cdq(L,R,dep-1,f,g,brute);
if(!l&&i)calc0(f,g,i,dep);//l=0的话算一下_f0和_g0
if(R==r)return;
fill(_f[dep][i],_f[dep][i]+(d<<1),0);copy(f+L,f+R+1,_f[dep][i]);DFT(_f[dep][i],d<<1);
fill(_g[dep][i],_g[dep][i]+(d<<1),0);copy(g+L,g+R+1,_g[dep][i]);DFT(_g[dep][i],d<<1);
if(!l)//计算第0个子问题的部分贡献
{
for(int k=0;k<(d<<1);k++)tmpf[k]=1ll*_f[dep][0][k]*_g[dep][i][k]%mod;
IDFT(tmpf,d<<1);
for(int i=d;i<(d<<1);i++)f[L+i]=qmod1(f[L+i]+tmpf[i]);
}
}
}
void cdq_pre(int f[],int g[],int n,void brute(int*,int*,int,int))
{
fill(f,f+n,0);
// if(n<=128){brute(f,g,0,n-1);return;}
int len=1,dep=0;while(len<n)len<<=B,++dep;len>>=B;
int *p=pool;
//clock_t begt=clock();
for(int i=1;i<=dep;i++)
{
int d=1<<((i-1)*B),tn=min((1<<B)-1,(n-1)/d);
for(int j=1;j<=tn;j++)
{
_f[i][j-1]=p,p+=d<<1;_g0[i][j]=p;p+=d<<1;
_f0[i][j]=p,p+=d<<1;_g[i][j-1]=p,p+=d<<1;
}
}
cdq(0,n-1,dep,f,g,brute);
}
#undef B
}
void update(int x)
{
int t=1ll*x*f[x]%mod;
for(int i=x;i<n;i+=x)
g[i]=qmod1(g[i]+t);
}
void brute(int f[],int g[],int l,int r)
{
if(!l)f[l]=g[l]=0,++l;
for(int i=l;i<=r;i++)
{
if(i==1){f[i]=1;update(i);continue;}
for(int j=l;j<i;j++)f[i]=(f[i]+1ll*f[j]*g[i-j]+1ll*g[j]*f[i-j])%mod;
f[i]=1ll*f[i]*Poly::inv[i-1]%mod;update(i);
}
}
int getin()
{
int x=0;char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9')x=x*10+ch-48,ch=getchar();
return x;
}
int main()
{
// freopen("data.txt","r",stdin);
n=getin()+1;Poly::prework(n);
Poly::cdq_pre(f,g,n,brute);
// for(int i=0;i<n;i++)cout<<f[i]<<" ";puts("");
--n;
int ans=f[n];
for(int i=(n>>1)+1;i<n;i++)ans=qmod2(ans-1ll*f[i]*f[n-i]%mod);
if(~n&1)ans=qmod2(ans-1ll*f[n>>1]*(f[n>>1]-1)/2%mod);
cout<<ans<<endl;
}
不过