CDQ 分治详解

· · 算法·理论

CDQ 分治

CDQ 分治是一种离线的分治算法,和 dp 一样,是一种解决问题的思想,主要思想为用一个子问题来计算另一个子问题。CDQ 分治可以套很多层,每多一层时间复杂度就要乘上一个 \log n (我以后学会了再补上)。

在学习 CDQ 分治之前,你应该已经比较熟练的掌握分治算法。

三维偏序

三维偏序一般可以作为 CDQ 分治的模板题来做,其形式为:

n 个元素,第 i 个元素有 a_i,b_i,c_i 三个属性,设 f(i) 表示满足 a_j \leq a_i b_j \leq b_i c_j \leq c_i j \ne i j 的数量。 对于 d \in [0, n) ,求 f(i) = d 的数量。

此内容来自P3810 【模板】三维偏序(陌上花开),感觉这个问题问的很绕,我们稍作修改,只询问有多少对 (i,j) 满足 a_j \leq a_i b_j \leq b_i c_j \leq c_i j \ne i ,接下来的内容将围绕此题展开。

三维偏序对现在的我们来说有点太难了,我们先考虑低一点的——一维偏序。

显然一维偏序直接 sort 一下就可以了。

接下来看二维的。

我们参考一维的形式,先将这 n 个元素按照 a 为第一关键字,b 为第二关键字进行排序,这样的话第一维已经有序,只要满足 i\ge j,就是满足了 a_i\ge a_j

接下来思考如何处理第二维,由于第二维已经在第一维相同的时候排好序了,所以计算 i 的答案的时候可以直接遍历 j1i-1,因为 i>j,必定有 a_i\ge a_j,所以此时只要有满足 b_i\ge b_j,就出现了一组偏序,记录答案即可。

考虑到这样做的时间复杂度为 O(n^2),需要优化。

我们维护一个数组 s。并且同时维护其前缀和 sum_i=\sum\limits_{j=1}^{i}s_j

遍历 i1n,我们查询 sum_{a_i},然后将 s_{a_i} 加上 1,这样查询到的数就是前边比 b_i 小的 b_j 的个数。

于是使用树状数组进行维护,支持单点修改和区间查询即可。时间复杂度 O(n\log n)

看到这里其实可以发现逆序对其实就是一个属性为下标(即已经有序)的二维偏序,同样可以使用上述方法解决。

好了,说了这么多,直接进入正题——三维偏序。

先是跟二维偏序一样的套路,将这 n 个元素按照 a 为第一关键字,b 为第二关键字,c 为第三关键字进行排序。

接下来,就是 CDQ 分治分出 b 这一维,再用树状数组解决第三维了。

听起来是不是很简单?我们只需要关注如何使用 CDQ 分治解决第二维就可以了。

分治一个区间 [l,r] 过程是这样的:

首先,找到其中点 mid

ij 的关系可以分为以下三种:

  1. l\le j<i\le mid
  2. mid<j<i \le r
  3. l\le j\le mid< i\le r

对于前两种情况,可以看出其为区间 [l,r] 的一个子问题,直接分治递归求解即可。

我们主要考虑第三种情况,由于要保证 a_i\ge a_j,所以 i 一定在 j 的后边。

[l,mid][mid+1,r] 内的元素分别按照 b 为第一关键字,c 为第二关键字排序。

imid+1 遍历到 r,每走一步,就让 j 走到第一个满足 b_i<b_j 的位置,那么 jl 到现在位置的前一个位置这段区间内都是可能对 i 做贡献的,因为 j 是第一个不满足 b_i\ge b_j 的,又因为已经按照 b 排过序了,所以前边的 j 都满足 b_i\ge b_j。而且因为 i 在右半边,j 在左半边,永远满足 a_i\ge a_j,所以已经满足了前两维,那么第三维直接使用树状数组计算答案即可。

代码实现

void cdq(int l,int r) {
    if(l>=r)return;
    int mid=l+r>>1;
    cdq(l,mid);//情况1
    cdq(mid+1,r);//情况2
    int i,j=l;//两个指针
    sort(s+l,s+mid+1,cmpB);
    sort(s+mid+1,s+r+1,cmpB);//左半边和右半边分别排序,b为第一关键字,c为第二关键字
    for(i=mid+1;i<=r;i++) {
        while(s[i].b>=s[j].b&&j<=mid) {//找到所有满足条件的j
            bit.add(s[j].c,s[j].cnt);//满足条件直接在树状数组上加,s[j].cnt指的是这一元素出现多少次,下边解释
            j++;
        }
        ans+=bit.query(s[i].c);//计算答案
    }
    for(i=l;i<j;i++)bit.add(s[i].c,-s[i].cnt);//一定要清空
}

注意到代码里记录了相同元素的出现次数,这里说明一下:

CDQ 分治在遇到重复元素的时候无法计算重复元素之间的贡献,除非不计算重复元素的贡献,否则需要进行去重。

这里再贴一下模板题的代码:

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
const int M=1e6+5;
int n,k,m,tot;
struct BIT {
    int bit[N];
    inline int lowbit(int x){return x&-x;}
    inline void add(int x,int d){for(;x<=k;x+=lowbit(x))bit[x]+=d;}
    inline int query(int x){int res=0;for(;x;x-=lowbit(x))res+=bit[x];return res;}
}bit;
struct Node {
    int a,b,c,cnt,ans;
}ss[N],s[N];
int ans[N];
bool cmpA(Node a,Node b) {
    if(a.a==b.a) {
        return a.b==b.b?a.c<b.c:a.b<b.b;
    }
    return a.a<b.a;
}
bool cmpB(Node a,Node b) {
    return a.b==b.b?a.c<b.c:a.b<b.b;
}
void cdq(int l,int r) {
    if(l>=r)return;
    int mid=l+r>>1;
    cdq(l,mid);
    cdq(mid+1,r);
    int i,j=l;
    sort(s+l,s+mid+1,cmpB);
    sort(s+mid+1,s+r+1,cmpB);
    for(i=mid+1;i<=r;i++) {
        while(s[i].b>=s[j].b&&j<=mid) {
            bit.add(s[j].c,s[j].cnt);
            j++;
        }
        s[i].ans+=bit.query(s[i].c);
    }
    for(i=l;i<j;i++)bit.add(s[i].c,-s[i].cnt);
}
int main() {
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++) {
        scanf("%d%d%d",&ss[i].a,&ss[i].b,&ss[i].c);
    }
    sort(ss+1,ss+n+1,cmpA);
    for(int i=1;i<=n;i++) {
        tot++;
        if(ss[i].a!=ss[i+1].a||ss[i].b!=ss[i+1].b||ss[i].c!=ss[i+1].c) {
            s[++m]=ss[i];
            s[m].cnt=tot;
            tot=0;
        }
    }
    cdq(1,m);
    for(int i=1;i<=n;i++)
        ans[s[i].cnt+s[i].ans-1]+=s[i].cnt;//这里实际上就是计算重复元素的贡献
    for(int i=0;i<n;i++)
        printf("%d\n",ans[i]);
    return 0;
}