杜教筛学习笔记

· · 个人记录

本来想放在 数论初步 里的,但是那篇文章实在太长了,编辑的时候标签页崩溃了。。。

上面的文章为前置知识。

参考资料:

OI Wiki

模板题题解

理论基础

杜教筛是一种可以在 \mathcal O(n^{\frac 23}) 的时间内算出一个积性函数的前缀和的算法。

数论基础

设有一个积性函数 \mathbf f,它的前缀和为 \mathbf S(n)=\sum_{i=1}^n\mathbf f(i)。再找一个积性函数 \mathbf g,考虑它们狄利克雷卷积的前缀和,可得:

\begin{aligned} \sum_{i=1}^n\mathbf{(f*g)}(i)&=\sum_{i=1}^n\sum_{d\mid i}\mathbf f(\frac id)\mathbf g(d)\\ &=\sum_{d=1}^n\mathbf g(d)\sum_{i=1}^{\lfloor\frac nd\rfloor}\mathbf f(i)\\ &=\sum_{d=1}^n\mathbf g(d)\mathbf S(\lfloor\frac nd\rfloor) \end{aligned}

则有

\begin{aligned} \mathbf g(1)\mathbf S(n)&=\sum_{d=1}^n\mathbf g(d)\mathbf S(\lfloor\frac nd\rfloor)-\sum_{d=2}^n\mathbf g(d)\mathbf S(\lfloor\frac nd\rfloor)\\ &=\sum_{i=1}^n\mathbf{(f*g)}(i)-\sum_{d=2}^n\mathbf g(d)\mathbf S(\lfloor\frac nd\rfloor) \end{aligned}

由于 \mathbf g 是积性函数,所以 \mathbf g(1)=1。因此

\mathbf S(n)=\sum_{i=1}^n\mathbf{(f*g)}(i)-\sum_{d=2}^n\mathbf g(d)\mathbf S(\lfloor\frac nd\rfloor)

如果我们可以快速计算 \mathbf {f*g}\mathbf g 的前缀和,就可以利用数论分块递归计算 \mathbf S。另外,为了优化时间复杂度,可以 \mathcal O(n^{\frac 23}) 预处理出前 n^{\frac 23} 个数的答案,之后再用杜教筛。然而我不会证明时间复杂度(不会微积分太不方便了)。 现在会了。见下文。

对于记录答案,可以用 map,但更好的做法是观察性质。我们知道

\left\lfloor\frac{\lfloor\frac xy\rfloor}z\right\rfloor=\left\lfloor\frac x{yz}\right\rfloor

证明:设 x=ay+b(b<y)a=cz+d(d<z),则左式显然为 c,而 x=(cz+d)y+b=cyz+dy+b,又因为 dy+b\le (z-1)y+y-1<yz,所以右式也等于 c,得证。

所以每次递归计算的一定是某个 \mathbf S(\lfloor\dfrac nx\rfloor)。由于当 x>n^{\frac 13} 时会直接获取预处理出的答案,所以需要记忆化的数一定少于 n^{\frac 13},因此可以开一个哈希表记录答案。

所以我们要设计一个哈希函数 H,保证对于任意的 \lfloor\dfrac nx\rfloor\ne \lfloor\dfrac ny\rfloor\land x,y\le n^{\frac 13},都有 H(\lfloor\dfrac nx\rfloor)\ne H(\lfloor\dfrac ny\rfloor)。可以取 H(i)=\lfloor\dfrac ni\rfloor。因为

k=H(\lfloor\dfrac nx\rfloor)=\left\lfloor\frac n{\lfloor\frac nx\rfloor}\right\rfloor

是使得取值为 \lfloor\dfrac nk\rfloor=\lfloor\dfrac nx\rfloor 的最大的 k(证明见前置知识中的数论分块部分),所以 \lfloor\dfrac nx\rfloorH(\lfloor\dfrac nx\rfloor) 构成了一个双射。

时间复杂度证明

我们已经证明了 \mathbf S 的自变量取值一定是某个 \lfloor\dfrac nx\rfloor,所以只有 O(\sqrt n) 种状态。最多的时候状态集合为 \{x\mid x\in\mathbb N\cap[1,\sqrt n]\}\cap\{\lfloor\dfrac nx\rfloor\mid x\in\mathbb N\cap[1,\sqrt n]\}。而计算一个状态 x 的时间复杂度为 O(\sqrt x)。所以如果不预处理,总的时间复杂度为

\begin{aligned} O(\sum_{i=1}^{\sqrt n}\sqrt i+\sum_{i=1}^{\sqrt n}\sqrt{\lfloor\frac ni\rfloor})&=O(\sum_{i=1}^{\sqrt n}\sqrt{\lfloor\frac ni\rfloor})\\ &=O(\int_1^{\sqrt n}\sqrt{\frac nx}\mathrm dx)\\ &=O(\sqrt n\int_1^{\sqrt n}\frac 1{\sqrt x}\mathrm dx)\\ &=O(n^{\frac 12}\int_1^{\sqrt n}x^{-\frac 12}\mathrm dx)\\ &=O(n^{\frac 34}) \end{aligned}

如果线性预处理出 [1,n^a] 的答案且 a>\dfrac 12,则时间复杂度为

\begin{aligned} O(n^a+\sum_{i=1}^{n^{1-a}}\sqrt{\lfloor\frac ni\rfloor})&=O(n^a+\int_1^{n^{1-a}}\sqrt{\frac nx}\mathrm dx)\\ &=O(n^a+n^{\frac 12}\int_1^{n^{1-a}}x^{-\frac 12}\mathrm dx)\\ &=O(n^a+n^{1-\frac{a}{2}}) \end{aligned}

可以看出,a\dfrac 23 时最优,时间复杂度为 O(n^{\frac 23})

代码实现

模板题,提交记录。

知道杜教筛的理论后,只要知道 \varphi*\mathbf 1=\mathbf{id}\mu*\mathbf 1=\varepsilon以及数据的毒瘤程度,就可以写出代码了:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const int N=1.7e6,sN=1300;
int _n,v[N],phi[N],mu[N];
ll preP[N],preM[N],sumP[sN],sumM[sN];
vector<int>pr;
inline void init(){
    preP[1]=phi[1]=1;preM[1]=mu[1]=1;
    for(int i=2;i<N;i++){
        if(!v[i])pr.push_back(v[i]=i),mu[i]=-1,phi[i]=i-1;
        for(int p:pr){
            if(p>v[i]||p>(N-1)/i)break;
            v[i*p]=p;if(v[i]==p)phi[i*p]=phi[i]*p,mu[i*p]=0;
            else phi[i*p]=phi[i]*phi[p],mu[i*p]=mu[i]*mu[p];
        }
        preP[i]=preP[i-1]+phi[i];preM[i]=preM[i-1]+mu[i];
    }
}
ll getSumPhi(int n){
    if(n<N)return preP[n];
    if(sumP[_n/n])return sumP[_n/n];
    ll ret=((ll)n*(n+1))>>1;for(int l=2,r=1;l<=n;l=r+1){
        r=n/(n/l);
        ret-=getSumPhi(n/l)*(r-l+1);
    }
    return sumP[_n/n]=ret;
}
ll getSumMu(int n){
    if(n<N)return preM[n];
    if(sumM[_n/n])return sumM[_n/n];
    ll ret=1;for(int l=2,r=1;l<=n;l=r+1){
        r=n/(n/l);
        ret-=getSumMu(n/l)*(r-l+1);
    }
    return sumM[_n/n]=ret;
}
inline void ct(){
    scanf("%d",&_n);
    if(_n==2147483647){// 当 n=(1<<31)-1 时,int 会爆,所以要特判
        printf("%lld %d\n",1401784457568941916,9569);
        return;
    }
    memset(sumP,0,sizeof(sumP));
    memset(sumM,0,sizeof(sumM));
    printf("%lld %lld\n",getSumPhi(_n),getSumMu(_n));
}
int main(){
    init();int T;scanf("%d",&T);
    while(T--)ct();
    return 0;
}