半在线卷积

· · 个人记录

本来想交到分治 FFT 的题解里去的,结果题解通道关了,那就不管了(

大概是计算形如

f_n=c_i\sum_{i=1}^nf_ig_{n-i}

中的 f,其中 c,g 不一定是直接给出的,可能与 f 有关(比如随便编一个 f 的前缀异或和啥的)

与之相对的,全在线卷积就是必须得到 f_{0\cdots n-1} 之后才能得到 g_nc_n. 由于我也不懂的原因,这两个问题复杂度相同.

可以用这个东西比较方便地架构多项式板子

对 cdq 分治 NTT 的优化

现在先假设 g 全部已知,这也是架构多项式板子时会遇到的情况.

我们知道正常的 cdq 分治 NTT 是分治左边,计算左边对右边的贡献,然后分治右边. 这样具有复杂度 O(n\log^2 n). 不过我们可以试着分成 B 个子问题,然后计算两两之间的贡献. 乍一看需要 \Theta(B^2)\Theta(\dfrac{n}{B}\log \dfrac{n}{B}) 的卷积,但注意到我们可以记录下来点值,这样计算一次贡献就是 \Theta(\dfrac{n}{B}) 的了,也就是我们的复杂度为

T(n)=BT(n/B)+\Theta(nB+n\log\frac{n}{B})

注意这里 B 相对于 n 不是常数,所以不能用主定理分析复杂度.

B=\Theta(\log n),得到复杂度 T(n)=\Theta\left(\dfrac{n\log^2n}{\log\log n}\right)

然而说起来容易写起来难,为了跑得够快需要精细的实现:

  1. B 取为 816 比较快
  2. 小范围暴力
  3. 分出来的 B 个子问题的大小取为 2 的幂,最后一个不整的块无所谓,这样不浪费 DFT 长度
  4. 公式的细节,假设我们当前分出来的子问题大小为 d,那么第 i 个块(从它 0 标号)的范围为 [id,(i+1)d)(对于这个区间而言),所以计算第 i 个块对第 j 个块的贡献时有用的 g 的区间为 ((j-i-1)d,(j-i+1)d),并且每层内这个部分都是一样的所以可以在分治外部预处理 g 在每一层的点值. 此外我们是做次数为 2d-1d-1 的多项式乘法并且只取 [d,2d-1] 项系数,所以可以直接做长度为 2d 的 DFT,这个循环卷积刚好溢出不到我们需要的位置.

代码相当漂亮

    #define B 3
    int *_f[20][1<<B],*_g[20][1<<B];
    void cdq(int l,int r,int dep,int f[],const int g[],void brute(int*,const int*,int,int))
    {
        if(r-l+1<=32){brute(f,g,l,r);return;}//暴力是传入的,根据需要写
        static int tmpf[N];
        int d=1<<((dep-1)*B);//这一层的块大小
        for(int i=0;;i++)
        {
            int L=l+i*d,R=min(r,L+d-1);//子问题区间
            if(i)
            {
                fill(tmpf,tmpf+(d<<1),0);
                for(int j=0;j<i;j++)//算贡献
                    for(int k=0;k<(d<<1);k++)
                        tmpf[k]=(tmpf[k]+1ll*_f[dep][j][k]*_g[dep][i-j][k])%mod;
                IDFT(tmpf,d<<1);
                for(int i=L;i<=R;i++)f[i]=qmod1(f[i]+tmpf[i-L+d]);
            }
            cdq(L,R,dep-1,f,g,brute);
            if(R==r)return;
            fill(_f[dep][i],_f[dep][i]+(d<<1),0);//_f记录点值
            copy(f+L,f+R+1,_f[dep][i]);
            DFT(_f[dep][i],d<<1);
        }
    }
    void cdq_pre(int f[],const int g[],int n,void brute(int*,const int*,int,int))
    {
        fill(f,f+n,0);
        if(n<=128){brute(f,g,0,n-1);return;}
        int len=1,dep=0;while(len<n)len<<=B,++dep;len>>=B;
        int *p=pool;
        //clock_t begt=clock();
        for(int i=1;i<=dep;i++)//枚举所有的层
        {
            int d=1<<((i-1)*B),tn=min((1<<B)-1,(n-1)/d);//子问题个数和子问题长度
            for(int j=1;j<=tn;j++)
            {
                int l=(j-1)*d+1,r=min(n-1,(j+1)*d-1);//_g记录g的转移点值
                _f[i][j-1]=p,p+=d<<1;_g[i][j]=p;p+=d<<1;
                fill(_g[i][j],_g[i][j]+(d<<1),0);
                copy(g+l,g+r+1,_g[i][j]+1);
                DFT(_g[i][j],d<<1);
            }
        }
        cdq(0,n-1,dep,f,g,brute);
    }
    #undef B

那剩下的东西就很简单了,我们只需要决定 brute 怎么写.

nf_n=ng_n-\sum_{i=1}^n(n-i)f_{n-i}g_i

(其实有了这俩已经能把剩下的都表示出来了)

g_0f_n=-\sum_{i=1}^ng_if_{n-i}

不过显然这个代码并不是最终的板子(我暂时也不打算写),它甚至连乘法都没封装进去( 最终的版本大概还是要写成 vector 形式,比较易用.

半在线卷积

上面的问题里我们的 g 是已知的,但现在 g 可能不是已知的.

例题

我们先快进到无标号有根树的生成函数方程

F(z)=z\exp\left(\sum_{i\geq 1}\frac{1}{i}F(z^i)\right)

那么

zF'(z)=F(z)+F(z)\sum_{i\geq 1}z^iF(z^i)

G(z)=\sum_{i\geq 1}z^iF(z^i)=\sum_{i\geq 1}z^i\sum_{d\mid i}df_d

显然 g 的系数可以用总共 \Theta(n\ln n) 的时间从 f 得到. 现在的递推式就是

(n-1)f_n=\sum_{i=1}^{n-1}g_if_{n-i}

边界为 f_1=1,g_0=f_0=0.

先考虑正常的二叉分治(左区间长度为区间长度的 highbit),那么算贡献的时候需要由 f_{l,mid} 卷上 g_{1,r-l} 得到对 f_{mid+1,r} 的贡献. 但是当 l=0 的时候 g_{1,r-l} 还没有计算出来,其他情况下一定有 r-l<l. 因此我们分类计算:

而这个东西稍微改改就能变成我们的 B 叉分治:

只有 l=0 时第 0 个子问题对其他子问题的贡献无法计算.

可能代码会更清楚一些(

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<ctime>
#include<cstdlib>
using namespace std;
typedef unsigned long long ull;
const int mod=998244353;
const int N=5e5;
int qmod1(int x){return x>=mod?x-mod:x;}
int qmod2(int x){return x+(x>>31&mod);}
int n,f[N],g[N];
namespace Poly
{
    int w[N],inv[N],pool[N<<4];
    int limn,rev[N],lg[N];
    int qpower(int a,int b){int ans=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;return ans;}
    void prework(int n)
    {
        limn=1;while(limn<n)limn<<=1;
        for(int i=1;i<limn;i++)rev[i]=rev[i>>1]>>1|((i&1)?(limn>>1):0);
        for(int i=1;i<limn;i<<=1)
        {
            w[i]=1;int omg=qpower(3,(mod-1)/(i<<1));
            for(int j=1;j<i;j++)w[i+j]=1ll*w[i+j-1]*omg%mod;
        }
        for(int i=2;i<=limn;i++)lg[i]=lg[i>>1]+1;
        inv[1]=1;for(int i=2;i<limn;i++)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
    }
    void DFT(int a[],int n)
    {
        int t=lg[limn/n];
        static ull tmp[N];
        for(int i=0;i<n;i++)tmp[rev[i]>>t]=a[i];
        for(int i=1;i<n;i<<=1)
            for(int j=0;j<n;j+=i<<1)
                for(int k=0;k<i;k++)
                {
                    unsigned t=tmp[i+j+k]*w[i+k]%mod;
                    tmp[i+j+k]=tmp[j+k]+mod-t,tmp[j+k]+=t;
                }
        for(int i=0;i<n;i++)a[i]=tmp[i]%mod;
    }
    void IDFT(int a[],int n)
    {
        reverse(a+1,a+n);DFT(a,n);
        int iv=mod-(mod-1)/n;
        for(int i=0;i<n;i++)a[i]=1ll*a[i]*iv%mod;
    }
    void print(int a[],int n)
    {
        for(int i=0;i<n;i++)cout<<a[i]<<" ";puts("");
    }
    #define B 3
    int *_f0[20][1<<B],*_g0[20][1<<B],*_f[20][1<<B],*_g[20][1<<B];
    //_f0,_g0 和上个问题中的 _g 的作用类似,_f,_g 和上个问题中的 _f 作用类似
    void calc0(const int f[],const int g[],int i,int dep)
    {
        int d=1<<((dep-1)*B);
        int l=(i-1)*d+1,r=(i+1)*d-1;
        fill(_g0[dep][i],_g0[dep][i]+(d<<1),0),fill(_f0[dep][i],_f0[dep][i]+(d<<1),0);
        copy(g+l,g+r+1,_g0[dep][i]+1),copy(f+l,f+r+1,_f0[dep][i]+1);
        DFT(_g0[dep][i],d<<1),DFT(_f0[dep][i],d<<1);
    }
    void cdq(int l,int r,int dep,int f[],int g[],void brute(int*,int*,int,int))
    {
        if(r-l+1<=1){brute(f,g,l,r);return;}
        static int tmpf[N];
        int d=1<<((dep-1)*B);
        for(int i=0;;i++)
        {
            int L=l+i*d,R=min(r,L+d-1);
            if(i)
            {
                fill(tmpf,tmpf+(d<<1),0);
                for(int j=l?0:1;j<i;j++)
                    for(int k=0;k<(d<<1);k++)
                        tmpf[k]=(tmpf[k]+1ll*_f[dep][j][k]*_g0[dep][i-j][k])%mod;
                if(l)//l!=0还需要加上的贡献
                {
                    for(int j=0;j<i;j++)
                        for(int k=0;k<(d<<1);k++)
                            tmpf[k]=(tmpf[k]+1ll*_g[dep][j][k]*_f0[dep][i-j][k])%mod;
                }
                IDFT(tmpf,d<<1);
                for(int i=L;i<=R;i++)f[i]=qmod1(f[i]+tmpf[i-L+d]);
            }
            cdq(L,R,dep-1,f,g,brute);
            if(!l&&i)calc0(f,g,i,dep);//l=0的话算一下_f0和_g0
            if(R==r)return;                
            fill(_f[dep][i],_f[dep][i]+(d<<1),0);copy(f+L,f+R+1,_f[dep][i]);DFT(_f[dep][i],d<<1);
            fill(_g[dep][i],_g[dep][i]+(d<<1),0);copy(g+L,g+R+1,_g[dep][i]);DFT(_g[dep][i],d<<1);
            if(!l)//计算第0个子问题的部分贡献
            {
                for(int k=0;k<(d<<1);k++)tmpf[k]=1ll*_f[dep][0][k]*_g[dep][i][k]%mod;
                IDFT(tmpf,d<<1);
                for(int i=d;i<(d<<1);i++)f[L+i]=qmod1(f[L+i]+tmpf[i]);
            }
        }
    }
    void cdq_pre(int f[],int g[],int n,void brute(int*,int*,int,int))
    {
        fill(f,f+n,0);
//        if(n<=128){brute(f,g,0,n-1);return;}
        int len=1,dep=0;while(len<n)len<<=B,++dep;len>>=B;
        int *p=pool;
        //clock_t begt=clock();
        for(int i=1;i<=dep;i++)
        {
            int d=1<<((i-1)*B),tn=min((1<<B)-1,(n-1)/d);
            for(int j=1;j<=tn;j++)
            {
                _f[i][j-1]=p,p+=d<<1;_g0[i][j]=p;p+=d<<1;
                _f0[i][j]=p,p+=d<<1;_g[i][j-1]=p,p+=d<<1;
            }
        }
        cdq(0,n-1,dep,f,g,brute);
    }
    #undef B
}

void update(int x)
{
    int t=1ll*x*f[x]%mod;
    for(int i=x;i<n;i+=x)
        g[i]=qmod1(g[i]+t);
}
void brute(int f[],int g[],int l,int r)
{
    if(!l)f[l]=g[l]=0,++l;
    for(int i=l;i<=r;i++)
    {
        if(i==1){f[i]=1;update(i);continue;}
        for(int j=l;j<i;j++)f[i]=(f[i]+1ll*f[j]*g[i-j]+1ll*g[j]*f[i-j])%mod;
        f[i]=1ll*f[i]*Poly::inv[i-1]%mod;update(i);
    }
}
int getin()
{
    int x=0;char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9')x=x*10+ch-48,ch=getchar();
    return x;
}
int main()
{
//    freopen("data.txt","r",stdin);
    n=getin()+1;Poly::prework(n);
    Poly::cdq_pre(f,g,n,brute);
//    for(int i=0;i<n;i++)cout<<f[i]<<" ";puts("");
    --n;
    int ans=f[n];
    for(int i=(n>>1)+1;i<n;i++)ans=qmod2(ans-1ll*f[i]*f[n-i]%mod);
    if(~n&1)ans=qmod2(ans-1ll*f[n>>1]*(f[n>>1]-1)/2%mod);
    cout<<ans<<endl;
}

不过 B=1 也就是正常二叉分治的时候这个代码就比 rk2 快得多了(