CDQ 分治详解
sjr3065335594 · · 算法·理论
CDQ 分治
CDQ 分治是一种离线的分治算法,和 dp 一样,是一种解决问题的思想,主要思想为用一个子问题来计算另一个子问题。CDQ 分治可以套很多层,每多一层时间复杂度就要乘上一个
在学习 CDQ 分治之前,你应该已经比较熟练的掌握分治算法。
三维偏序
三维偏序一般可以作为 CDQ 分治的模板题来做,其形式为:
有
此内容来自P3810 【模板】三维偏序(陌上花开),感觉这个问题问的很绕,我们稍作修改,只询问有多少对
三维偏序对现在的我们来说有点太难了,我们先考虑低一点的——一维偏序。
显然一维偏序直接 sort 一下就可以了。
接下来看二维的。
我们参考一维的形式,先将这
接下来思考如何处理第二维,由于第二维已经在第一维相同的时候排好序了,所以计算
考虑到这样做的时间复杂度为
我们维护一个数组
遍历
于是使用树状数组进行维护,支持单点修改和区间查询即可。时间复杂度
看到这里其实可以发现逆序对其实就是一个属性为下标(即已经有序)的二维偏序,同样可以使用上述方法解决。
好了,说了这么多,直接进入正题——三维偏序。
先是跟二维偏序一样的套路,将这
接下来,就是 CDQ 分治分出
听起来是不是很简单?我们只需要关注如何使用 CDQ 分治解决第二维就可以了。
分治一个区间
首先,找到其中点
则
-
l\le j<i\le mid -
mid<j<i \le r -
l\le j\le mid< i\le r
对于前两种情况,可以看出其为区间
我们主要考虑第三种情况,由于要保证
将
令
代码实现
void cdq(int l,int r) {
if(l>=r)return;
int mid=l+r>>1;
cdq(l,mid);//情况1
cdq(mid+1,r);//情况2
int i,j=l;//两个指针
sort(s+l,s+mid+1,cmpB);
sort(s+mid+1,s+r+1,cmpB);//左半边和右半边分别排序,b为第一关键字,c为第二关键字
for(i=mid+1;i<=r;i++) {
while(s[i].b>=s[j].b&&j<=mid) {//找到所有满足条件的j
bit.add(s[j].c,s[j].cnt);//满足条件直接在树状数组上加,s[j].cnt指的是这一元素出现多少次,下边解释
j++;
}
ans+=bit.query(s[i].c);//计算答案
}
for(i=l;i<j;i++)bit.add(s[i].c,-s[i].cnt);//一定要清空
}
注意到代码里记录了相同元素的出现次数,这里说明一下:
CDQ 分治在遇到重复元素的时候无法计算重复元素之间的贡献,除非不计算重复元素的贡献,否则需要进行去重。
这里再贴一下模板题的代码:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
const int M=1e6+5;
int n,k,m,tot;
struct BIT {
int bit[N];
inline int lowbit(int x){return x&-x;}
inline void add(int x,int d){for(;x<=k;x+=lowbit(x))bit[x]+=d;}
inline int query(int x){int res=0;for(;x;x-=lowbit(x))res+=bit[x];return res;}
}bit;
struct Node {
int a,b,c,cnt,ans;
}ss[N],s[N];
int ans[N];
bool cmpA(Node a,Node b) {
if(a.a==b.a) {
return a.b==b.b?a.c<b.c:a.b<b.b;
}
return a.a<b.a;
}
bool cmpB(Node a,Node b) {
return a.b==b.b?a.c<b.c:a.b<b.b;
}
void cdq(int l,int r) {
if(l>=r)return;
int mid=l+r>>1;
cdq(l,mid);
cdq(mid+1,r);
int i,j=l;
sort(s+l,s+mid+1,cmpB);
sort(s+mid+1,s+r+1,cmpB);
for(i=mid+1;i<=r;i++) {
while(s[i].b>=s[j].b&&j<=mid) {
bit.add(s[j].c,s[j].cnt);
j++;
}
s[i].ans+=bit.query(s[i].c);
}
for(i=l;i<j;i++)bit.add(s[i].c,-s[i].cnt);
}
int main() {
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++) {
scanf("%d%d%d",&ss[i].a,&ss[i].b,&ss[i].c);
}
sort(ss+1,ss+n+1,cmpA);
for(int i=1;i<=n;i++) {
tot++;
if(ss[i].a!=ss[i+1].a||ss[i].b!=ss[i+1].b||ss[i].c!=ss[i+1].c) {
s[++m]=ss[i];
s[m].cnt=tot;
tot=0;
}
}
cdq(1,m);
for(int i=1;i<=n;i++)
ans[s[i].cnt+s[i].ans-1]+=s[i].cnt;//这里实际上就是计算重复元素的贡献
for(int i=0;i<n;i++)
printf("%d\n",ans[i]);
return 0;
}