简介莫队算法

· · 个人记录

莫队算法是国家队队长莫涛发明的玄学算法。
人们称这个算法为「优雅的暴力」,又有传说「莫队算法可以解决一切修改查询问题」......

好吧进入正题,我们来讲讲这个神奇的莫队算法。

\text{Part1——普通莫队}

普通莫队算法适用于只有询问操作的问题。
相关的最著名的题目就数 P1494 了。

题目描述: 打死不点链接

我们先来搞一个比较「暴力」的做法。
对于每次询问 [l,r],我们可以记录一个 s 数组,其中 s_i 表示颜色 i 在区间 [l,r] 中出现的次数,a_i 表示第 i 个袜子的颜色。

那么一共就有 \left(\sum\limits_{i=l}^rs_i^2 \right)-(r-l+1) 种情况可以取到相同的颜色。

因为一个位置不能取两次,所以情况数要减去区间大小。
同样地,也可以推出来随便取时,总共有 (r-l+1)(r-l) 情况。 把这两个算比值就是答案了。

为了以后表述方便,我们设 \sum\limits_{i=l}^rs_i^2 q

重点来了:

假设我们已经算出区间 [l,r]s 数组和 q,现在要算区间 [l',r'] 了。那么我们不难想到维护以下四种操作:

计算 [l+1,r],[l-1,r],[l,r-1],[l,r+1] 四种情况下的结果,只要不断地调用这四种操作,就可以实现计算区间 [l',r'] 了。

这里只说明对于 [l+1,r] 的计算,只要明白了这个,另外几种也都能明白。

对于 s 的维护很简单,只需要将 s_{a_l} 减一即可。
而对于 q 的值,我们发现 a_lq 的贡献从s_{a_l}^2 变成了 (s_{a_l}-1)^2

于是我们就能写一个函数来维护更新操作:

void update(int i,int t){
    q -= sum[a[i]]*sum[a[i]];
    sum[a[i]] += t;
    q += sum[a[i]]*sum[a[i]];
}

其中 i 表示要更新的位置,t 表示增还是减。
然后那四种操作就能轻松搞定了:

update(l,-1),++l;
update(l-1,1),--l;
update(r+1,1),++r;
update(r,-1),--r;

这四行代码依次对应的操作为:
计算 [l+1,r],[l-1,r],[l,r+1],[l,r-1]

但是这样还不够。
不难想到,对于每一次操作,更新操作的总代价最坏情况为\text O(n) ,总复杂度还会被卡到 \text O(n^2),同时还会带着一大堆常数。

那这个算法的意义何在呢?
别忘了,这个题只有查询操作,我们大可以离线处理处每一次询问的答案,然后输出啊!
对于每一个询问,我们可以用这样的一个 \text{struct} 来存储:

struct query{
    int l,r,id;
    long long A,B;
    query(int l=0,int r=0,int id=0):l(l),r(r),id(id){}
};

其中 A,B 表示这次询问答案的分子和分母。

接下来可以对数组分段,我们可以分 \sqrt n 块,然后每一块的大小为 \sqrt n
这样有什么用呢?

对询问重新排列时,就可以根据 lr 所属的块来排序了。这里以 be_i 表示 i 位置在哪一块。
a,b 两个询问的左端点在同一块,按右端点排序,否则按左端点排序。

bool cmp(query a,query b){
    if(be[a.l]==be[b.l]) return a.r<b.r;
    return a.l<b.l;
}

然后就可以按照排好的顺序,按照之前的「暴力」做法就行了。
可以证明,这样做的时间复杂度为 \text O(n^{\frac32})

AC代码如下:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 50003
#define ll long long
using namespace std;

struct query{
    int l,r,id;
    ll A,B;
    query(int l=0,int r=0,int id=0):l(l),r(r),id(id){}
};

int be[N],a[N];
ll seq[N],sum[N];
query qy[N];
int n,m,block;
ll q;

bool cmp(query a,query b){
    //sort的时候输入这个函数,就会按这种方式排序
    if(be[a.l]==be[b.l]) return a.r<b.r;
    return a.l<b.l;
}

ll gcd(ll a,ll b){
    if(b==0) return a;
    return gcd(b,a%b);
}

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)){
        x = (x<<3)+(x<<1)+c-'0';
        c = getchar();
    }
}

void print(ll x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

inline void update(int i,int t){
    //上述的更新操作
    q -= sum[a[i]]*sum[a[i]];
    sum[a[i]] += t;
    q += sum[a[i]]*sum[a[i]];
}

int main(){
    int l,r;
    read(n),read(m);
    block = sqrt(n); //块的大小
    for(int i=1;i<=n;++i){
        read(a[i]);
        be[i] = i/block+1;
    }
    for(int i=1;i<=m;++i){
        read(l),read(r);
        qy[i] = query(l,r,i);
    }    
    sort(qy+1,qy+1+m,cmp);
    for(int i=1;i<=n;++i)
        seq[qy[i].id] = i; //记录顺序,seq[i]记录了第i次询问在qy中的下标
    l = 1;
    r = 0;
    for(int i=1;i<=m;++i){
        while(l<qy[i].l) update(l,-1),++l;
        while(l>qy[i].l) update(l-1,1),--l;
        while(r<qy[i].r) update(r+1,1),++r;
        while(r>qy[i].r) update(r,-1),--r;
        //上述的四种操作,移动左右端点
        if(qy[i].l==qy[i].r){
            //特判左右端点重叠的部分,颜色相同概率为0
            qy[i].A = 0;
            qy[i].B = 1;
            continue;
        }
        qy[i].A = q-(qy[i].r-qy[i].l+1);
        qy[i].B = (ll)(qy[i].r-qy[i].l+1)*(qy[i].r-qy[i].l);
        ll g = gcd(qy[i].A,qy[i].B);
        //由于要求最简分数,所以要求gcd约分
        qy[i].A /= g;
        qy[i].B /= g;
    }
    for(int i=1;i<=m;++i){
        print(qy[seq[i]].A);
        putchar('/');
        print(qy[seq[i]].B);
        putchar('\n');
    }
    return 0;
}

其它例题:

P2709 小B的询问
AT987 高桥君
P1972 HH的项链

\text{Part2——带修改莫队}

例题:P1903
前面提到的莫队算法都是离线的,现在来了个修改,那可怎么办?

在这里,我们可以给每次询问再添加一个参数:t,表示这次询问在几次修改之后。
然后每次处理询问时,按照修改的记录,使时间加速或倒流,然后就能实现莫队算法的修改了。

别急,如果你现在就急着去弄,搞不好就成了 \text O(n^2) 的暴力,下面我们讲解优化的方法。

对于普通莫队,排序询问时我们以 l 为第一关键字,r 为第二关键字排序。带修改莫队也差不多,只不过多了个以 t 为第三关键字。

只有 a,b 两个询问的左右端点都在同一块时,才按 t 排序。具体方法如下:

bool cmp(query a,query b){
    if(be[a.l]==be[b.l]){
        if(be[a.r]==be[b.r]) return a.t<b.t;
        return a.r<b.r;
    }
    return a.l<b.l;
}

现在,只存下询问操作显然不够用了,还需要把修改操作也存下来。

由于是单点修改,所以我们只需要 3 个参数:
修改位置、此位置上一个状态、下一个状态。
分别用 \text{pos , last , next} 来表示。

struct change{
    int pos,last,next;
    change(int pos=0,int last=0,int next=0):pos(pos),last(last),next(next){}
};

前面还提到:在处理询问的过程中,根据这次询问时第几次修改,来加速或者倒流时间。这就需要一个修改操作的函数了。

要修改一个位置之前,需要判断要修改的位置是否在当前询问的区间内。如果在区间内,还需要调用 \text{update} 函数。(类比之前的更新操作)

对于带修改的莫队算法,分块的大小不能再取 \sqrt n 了。用 均值不等式/求导 可以证明:
在分块的大小为 n^\frac23 时,时间复杂度达到最低,为 \text O(n^\frac53)

参考代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#define N 50003
using namespace std;

struct query{
    int l,r,t,id;
    query(int l=0,int r=0,int t=0,int id=0):l(l),r(r),t(t),id(id){}
};

struct change{
    int pos,last,next;
    change(int pos=0,int last=0,int next=0):pos(pos),last(last),next(next){}
};

int n,m,k,unit,res,l = 1,r,t;
int a[N],clr[1000003],now[N],be[N],ans[N];
query q[N];
change c[N];

bool cmp(query a,query b){
    if(be[a.l]==be[b.l]){
        if(be[a.r]==be[b.r]) return a.t<b.t;
        return a.r<b.r;
    }
    return a.l<b.l;
}

void update(int i,int t){
    clr[i] += t;
    if(t>0&&clr[i]==1) ++res;
    if(t<0&&clr[i]==0) --res;
}

void modify(int i,int t){
    if(l<=i&&i<=r){
        update(t,1);
        update(a[i],-1);
    }
    a[i] = t;
}

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int main(){
    read(n),read(m);
    unit = pow(n,0.66666666);
    for(int i=1;i<=n;++i){
        read(a[i]);
        now[i] = a[i];
        be[i] = i/unit+1;
    }
    for(int i=1;i<=m;++i){
        char op;
        int x,y;
        cin>>op;
        read(x),read(y);
        if(op=='Q'){
            ++k;
            q[k] = query(x,y,t,k);
        }else{
            ++t;
            c[t] = change(x,now[x],y);
            now[x] = y;
        }
    }
    sort(q+1,q+k+1,cmp);
    t = 0;
    for(int i=1;i<=k;++i){
        while(t<q[i].t) modify(c[t+1].pos,c[t+1].next),t++;
        while(t>q[i].t) modify(c[t].pos,c[t].last),t--;
        while(l<q[i].l) update(a[l],-1),l++;
        while(l>q[i].l) update(a[l-1],1),l--;
        while(r<q[i].r) update(a[r+1],1),r++;
        while(r>q[i].r) update(a[r],-1),r--;
        ans[q[i].id] = res;
    }
    for(int i=1;i<=k;++i){
        print(ans[i]);
        putchar('\n');
    }
    return 0;
}
\text{Part3——树上莫队}

莫队还能解决树上问题?你看这道:
P4074 糖果(狗粮)公园

莫队问题跑到树上去了,那不是爆炸吗?不过我们还是有解决办法。

还记得当时是怎么让莫队优化区间修改+查询问题的吗?就是分块+排序。现在问题跑到了树上,怎么分块呢?

我们可以考虑这样来分块:
从任意节点作为根,开始 dfs,每经过一个点就把它压进栈里。

当一个节点到栈顶的长度大于分块大小时,就把的这部分的点全部弹出,并分为一块。
对于最后剩下来的点,单独再分一块。

具体是什么样的呢?这里借用了一张图来解释:

此处块的大小为 3,按照我们刚才说的方法,就分成了这样。

具体代码实现如下:

void dfs(int u,int fa){
    int bottom = top; //top是一个全局变量,初始为0
    stack[++top] = u; //当前节点入栈
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==fa) continue;
        dfs(v,u);
        if(top-bottom>block){//栈顶到当前节点的长度大于分块大小
            ++idx;
            while(top!=bottom) be[stack[top--]] = idx; //弹出这部分元素,分为一块(be[i]表示i节点属于哪一块)
        }
    }
}

分块的问题解决了,排序又该怎么办?
跟普通序列上的莫队方法差不多:
以两端节点所属的块为第一、二关键字,以时间为第三关键字排序。

bool cmp(query a,query b){
    //u,v是询问路径的两端节点
    if(be[a.u]==be[b.u]){
        if(be[a.v]==be[b.v]) return a.t<b.t;
        return be[a.v]<be[b.v];
    }
    return be[a.u]<be[b.u];
}

最后,也是最难搞的问题:区间移动。

乍一看好像没什么难的,实际上也没什么难的。
一个节点要从 u 移到 v 的话,让它们不断地向上跳,直到跳到一起。

跳的过程中,用一个数组 \text{vis} 表示该节点在不在查询的路径上。
每到一个点,如果以前在路径上,那现在肯定就不在了;反之亦然。随着这样不断地更新经过的节点,不就得了?

结果你发现,如果移动节点时,要跨过 \text{lca}(u,v) 的话,问题就显现出来了,如下图:

当我们按上述步骤把 u 移到 u'v 移到 v' 时,奇怪的事情出现了:

那这好办,最后再更新一下这两个节点不就好了吗? 对于更新操作,可以维护一个数组 $\text{num}$,记录路径上各种糖果有多少个。当路径上增/减一种糖果时,对结果的贡献是 $\text{w}[\text{num}[i]] \times \text{v}[i]$,随便推一下就出来了。 确定分块的 dfs 可以和求 $\text{lca}$ 的预处理写在一起,减少码量。 代码如下: ```cpp #include<cstdio> #include<iostream> #include<algorithm> #include<cstring> #include<cmath> #include<queue> #include<vector> #define N 100003 #define reg register #define int long long using namespace std; struct query{ int u,v,t,id; query(int u=0,int v=0,int t=0,int id=0):u(u),v(v),t(t),id(id){} }; struct change{ int u,last,next; change(int u=0,int last=0,int next=0):u(u),last(last),next(next){} }; query q[N]; change c[N]; int val[N],w[N],fa[N],depth[N],son[N],stack[N]; int sub[N],top[N],be[N],clr[N],ans[N],sum[N]; //此处的sum数组就是上述num数组 bool vis[N]; int n,m,T,tp,block,idx,res; vector<int> adj[N]; inline void read(int &x); void print(int x); void dfs1(int u,int f); void dfs2(int u,int f); //树剖2遍dfs inline int lca(int u,int v); bool cmp(query a,query b); inline void del(int u); //删除u inline void add(int u); //添加u inline void update(int u); //更新u节点 inline void modify(int u,int t); //修改 void move(int u,int v); //移动u到v signed main(){ int op,t,qc,u,v; read(n),read(m),read(T); block = pow(n,0.666666); for(reg int i=1;i<=m;++i) read(val[i]); for(reg int i=1;i<=n;++i) read(w[i]); for(reg int i=1;i<n;++i){ read(u),read(v); adj[u].push_back(v); adj[v].push_back(u); } for(reg int i=1;i<=n;++i){ read(clr[i]); stack[i] = clr[i]; } t = qc = 0; for(reg int i=1;i<=T;++i){ read(op),read(u),read(v); if(op==0){ c[++t] = change(u,stack[u],v); stack[u] = v; }else{ ++qc; q[qc] = query(u,v,t,qc); } } memset(stack,0,sizeof(stack)); dfs1(1,0); dfs2(1,1); while(tp>0) be[stack[tp--]] = idx; sort(q+1,q+1+qc,cmp); t = 0; u = 1; v = 1; update(1); //u和v初始都在1,所以要更新一下1号点 for(int i=1;i<=qc;++i){ while(t<q[i].t) modify(c[t+1].u,c[t+1].next),++t; while(t>q[i].t) modify(c[t].u,c[t].last),--t; update(lca(u,v)); if(u!=q[i].u) move(u,q[i].u),u = q[i].u; if(v!=q[i].v) move(v,q[i].v),v = q[i].v; update(lca(u,v)); ans[q[i].id] = res; } for(int i=1;i<=qc;++i){ print(ans[i]); putchar('\n'); } return 0; } inline void read(int &x){ x = 0; char c = getchar(); while(c<'0'||c>'9') c = getchar(); while(c>='0'&&c<='9'){ x = (x<<3)+(x<<1)+(c^48); c = getchar(); } } void print(int x){ if(x>9) print(x/10); putchar(x%10+'0'); } void dfs1(int u,int f){ int bt = tp; stack[++tp] = u; fa[u] = f; depth[u] = depth[f]+1; sub[u] = 1; int v,t = -1,l = adj[u].size(); for(int i=0;i<l;++i){ v = adj[u][i]; if(v==f) continue; dfs1(v,u); if(tp-bt>block){ ++idx; while(tp!=bt) be[stack[tp--]] = idx; } sub[u] += sub[v]; if(sub[v]>t){ t = sub[v]; son[u] = v; } } } void dfs2(int u,int f){ top[u] = f; if(son[u]==0) return; dfs2(son[u],f); int v,l = adj[u].size(); for(int i=0;i<l;++i){ v = adj[u][i]; if(v==fa[u]||v==son[u]) continue; dfs2(v,v); } } inline int lca(int u,int v){ while(top[u]!=top[v]){ if(depth[top[u]]<depth[top[v]]) swap(u,v); u = fa[top[u]]; } if(depth[u]<depth[v]) return u; return v; } bool cmp(query a,query b){ if(be[a.u]==be[b.u]){ if(be[a.v]==be[b.v]) return a.t<b.t; return be[a.v]<be[b.v]; } return be[a.u]<be[b.u]; } inline void add(int u){ ++sum[u]; res += w[sum[u]]*val[u]; } inline void del(int u){ res -= w[sum[u]]*val[u]; --sum[u]; } inline void update(int u){ if(vis[u]) del(clr[u]),vis[u] = false; else add(clr[u]),vis[u] = true; } inline void modify(int u,int t){ if(vis[u]){ del(clr[u]); add(t); } clr[u] = t; } void move(int u,int v){ if(depth[u]<depth[v]) swap(u,v); while(depth[u]>depth[v]){ update(u); u = fa[u]; } while(u!=v){ update(u),update(v); u = fa[u],v = fa[v]; } } ```