K-D Tree从入门到精通
Jasmine_Aura · · 个人记录
K-D Tree 从入门到精通
本教程由 Jasmine_Aura 独立编写。
我的洛谷
K-D Tree基础
什么是K-D Tree?
k-D Tree(KDT , k-Dimension Tree) 是一种可以 高效处理
k-D Tree具有二叉搜索树的形态,其每个节点都对应
在题目中,一般
节点信息储存
我们通常用一个结构体储存k-D Tree所有节点的信息
//假设维护的是K维的空间信息
struct node
{
int d[K];//d[i]表示第i+1维的坐标(因为下标从0开始)
int mx[K],mn[K];//mx[i]和mn[i]表示该节点所管辖的第i+1维坐标极值,比较关键,后面会讲
int ls,rs;//左右儿子
int val;//节点权值,可能会用到
int sum;//节点子树的权值和,可能会用到
}t[MAX],ori[MAX];//t为k-D Tree所使用的数组
建树
k-D Tree常见的建树方式有两种:交替建树,方差建树。
交替建树代码实现简单,容易记忆,而方差建树较为复杂,用的较少,故这里只介绍交替建树的方法。
交替建树
交替建树有以下几个步骤:
- 将当前点集的所有点按第
d 维排序,取出中位数。 - 将取出的中位数作为当前点集的根节点,剩余的点集分别作为该点的左子树和右子树。
- 重复以上两个步骤,每次
d 要改成d+1 ,如果d+1\gt k ,则d=1 (其中k 为所有点的维度),直到所有树上的所有点都确定。
有点抽象?不怕!举个例子就懂了!
给出
首先按第
取出中位数
对于剩下的两个点集,我们继续建树,此时我们要对第
然后得到两个点集,
分别取出中位数
这样,我们就通过 交替建树 的方式得到了一颗层数为
这一过程可以用一个平面直角坐标系来表示:
找中位数是个比较棘手的事情,我们当然可以使用 sort() 排序,然后找到位置为 mid 的元素,但是这样的时间复杂度是
我们其实还可以使用这样一个函数 nth_element() ,将 l 和 r 之间的数按照比较规则 cmp 排序后,位置为 mid 的元素就是中位数,像这样 nth_element(ori+l,ori+mid,ori+r+1,cmp)。这样的时间复杂度是
k-D Tree建树整体的复杂度是
代码实现:
int K;//排序时的维度
bool cmp(node a,node b)
{
return a.d[K]<b.d[K];//按第K维排序
}
int build(int l,int r,int k)//l,r分别是集合的边界,k是排序时的维度
{
if(l>r)return 0;
int mid=(l+r)>>1;
K=k;
nth_element(ori+l,ori+mid,ori+r+1,cmp);//ori为点集
t[mid]=ori[mid];//t为k-D Tree所使用的数组
t[mid].ls=build(l,mid-1,k^1);//k^1,切换维度,交替建树
t[mid].rs=build(mid+1,r,k^1);
update(mid);//合并子树信息,后面会讲
return mid;//返回节点编号
}
子树信息合并
每个节点都需要维护 mx[i] 和 mn[i] , 即该节点所管辖的第
对于 (mx[0],mx[1]),(mn[0],mx[1]),(mx[0],mn[1]),(mn[0],mn[1]) 这四个点可以看作该节点所管辖矩形的四个顶点。
合并起来比较类似于线段树,看代码:
void update(int p)
{
int ls=t[p].ls;
int rs=t[p].rs;
t[p].sum=t[ls].sum+t[rs].sum+t[p].val;//维护节点权值和,可能会用到
for(int i=0;i<K;i++)//K为所有点的维度,例如平面上的点的K就是2
{
t[p].mx[i]=t[p].mn[i]=t[p].d[i];//初始化
if(ls)
{
t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
}
//这里比较容易写错,需要细心检查
}
}
mx[i] 和 mn[i] 非常关键,合理使用能极大优化查询时的复杂度。
下面放一个只需要建树和简单的查询的小k-D Tree模板题
洛谷 P4475 巧克力王国
给出
请思考后再看题解哦!!!
题解:
每个二元组可以抽象成平面上的点,使用k-D Tree来维护,建树就不用说了,把权值和维护上就行,现在主要来说查询。
我们已经维护了 mx[i] 和 mn[i] ,如果当前节点管辖的矩形的四个顶点都满足
int query(int x)
{
int tot=0;
tot+=check(t[x].mx[0],t[x].mx[1]);
tot+=check(t[x].mn[0],t[x].mn[1]);
tot+=check(t[x].mx[0],t[x].mn[1]);
tot+=check(t[x].mn[0],t[x].mx[1]);
if(tot==4)return t[x].sum;//当前节点管辖的矩形满足 ax+by<c
if(tot==0)return 0;//都不满足
int ans=0;//ans为答案
if(check(t[x].d[0],t[x].d[1]))ans+=t[x].val;//该节点满足 ax+by<c,先加上这个节点的权值
if(t[x].ls)ans+=query(t[x].ls);//递归左子树
if(t[x].rs)ans+=query(t[x].rs);//递归右子树
return ans;
}
完整代码(不要复制哦)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=5e4+4;
int a,b,c,K;
template<typename T>
void read(T &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=x*10+(ch^48);
ch=getchar();
}
x*=f;
}
struct node
{
int d[5];
int mx[5],mn[5];
int val,sum,ls,rs;
}ori[MAX],t[MAX];
bool cmp(node a,node b)
{
return a.d[K]<b.d[K];
}
bool check(int x,int y)
{
return a*x+b*y<c;
}
void update(int x)
{
int ls=t[x].ls;
int rs=t[x].rs;
for(int i=0;i<=1;i++)
{
t[x].mx[i]=t[x].mn[i]=t[x].d[i];
if(ls)
{
t[x].mx[i]=max(t[x].mx[i],t[ls].mx[i]);
t[x].mn[i]=min(t[x].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[x].mx[i]=max(t[x].mx[i],t[rs].mx[i]);
t[x].mn[i]=min(t[x].mn[i],t[rs].mn[i]);
}
}
t[x].sum=t[ls].sum+t[rs].sum+t[x].val;
}
int build(int l,int r,int k)
{
if(l>r)return 0;
int mid=(l+r)>>1;
K=k;
nth_element(ori+l,ori+mid,ori+r+1,cmp);
t[mid]=ori[mid];
if(l<mid)t[mid].ls=build(l,mid-1,k^1);
if(r>mid)t[mid].rs=build(mid+1,r,k^1);
update(mid);
return mid;
}
int query(int x)
{
int tot=0;
tot+=check(t[x].mx[0],t[x].mx[1]);
tot+=check(t[x].mn[0],t[x].mn[1]);
tot+=check(t[x].mx[0],t[x].mn[1]);
tot+=check(t[x].mn[0],t[x].mx[1]);
if(tot==4)return t[x].sum;
if(tot==0)return 0;
int ans=0;
if(check(t[x].d[0],t[x].d[1]))ans+=t[x].val;
if(t[x].ls)ans+=query(t[x].ls);
if(t[x].rs)ans+=query(t[x].rs);
return ans;
}
signed main()
{
int n,m;
read(n),read(m);
for(int i=1;i<=n;i++)
{
read(ori[i].d[0]),read(ori[i].d[1]),read(ori[i].val);
}
int root=build(1,n,0);
for(int i=1;i<=m;i++)
{
read(a),read(b),read(c);
cout<<query(root)<<endl;
}
return 0;
}
插入重构
插入
上面只是简简单单的静态的建树,如果我要插入点,你该怎么办?
首先,由于k-D Tree具有二叉搜索树的形态,插入的时候只需要和 当前节点的对应维度坐标 作比较即可,如果小于等于当前节点的第
维度的切换和交替建树一样,
例如,我要将
从
到达
到达
这样一来,插入操作就成功了,吗?
如果我疯狂插入几次极端的点,k-D Tree就有可能变成这样,层数就不再是
void insert(int &p,node temp,int k)//k为比较维度
{
if(!p)//k-D Tree中没这个点,直接插入
{
p=newnode();//给个新的编号
t[p]=temp;
t[p].rs=0,t[p].ls=0;//这步非常重要,不然左右儿子信息会混乱
update(p);//上传信息
return;
}
if(temp.d[k]<=t[p].d[k])insert(t[p].ls,temp,k^1);//插进左子树
else insert(t[p].rs,temp,k^1);//插进右子树
update(p);//上传信息
check(p,k);//判断是否失衡
//注意要先update()再check()
}
怎么判断是否需要重构?类似 替罪羊树(没学过也没关系,我也没学过),我们引入一个 平衡因子
在 update() 函数中,我们需要额外维护子树的大小。
暴力重构
类似静态建树,我们把失衡的子树拍成一个序列,重新对当前子树建树。
bool cmp(int a,int b)//比较的函数有所改变
{
return t[a].d[K]<t[b].d[K];
}
void pia(int p)
{
if(!p)return;
pia(t[p].ls);
g[++tot]=p;//把当前节点添到序列中
pia(t[p].rs);
}
void update(int p)
{
...
t[p].sz=t[ls].sz+t[rs].sz+1;
}
int rebuild(int l,int r,int k)//类似静态建树 ,参数有所改变
{
if(l>r)return 0;
int mid=(l+r)>>1;
K=k;
nth_element(g+l,g+mid,g+r+1,cmp);//注意这里传的是拍出的序列
t[g[mid]].ls=rebuild(l,mid-1,k^1);
t[g[mid]].rs=rebuild(mid+1,r,k^1);
update(g[mid]);
return g[mid];
}
void check(int &p,int k)
{
int ls=t[p].ls;
int rs=t[p].rs;
if(A*t[p].sz<max(t[ls].sz,t[rs].sz))//A为平衡因子
{
tot=0;//清空序列
pia(p);//把以p为根的子树拍成序列,方便重构
p=rebuild(1,tot,k);//暴力重构,并更新p点编号
}
}
这么难的东西都学会了,那么再来做一道例题罢!
洛谷 P4148 简单题
给一个
1 x y A 将格子 x,y 里的数字加上
2 x1 y1 x2 y2 输出
强制在线,内存限制
非常酷的题,强制在线卡掉了 CDQ分治,内存限制
观察到
插入操作解决了,现在考虑如何查询。
如果当前节点维护的区域完全被询问的矩形包含,那说明它子树中的所有节点都被包含,直接返回权值和。
如果当前节点维护的区域和询问的矩形完全没有交集,那说明它子树中的所有节点都不在矩形中,直接返回
如果当前节点维护的点在矩形内,给答案加上这个点的权值。
剩下的就可以直接递归该节点的左右子树,累加答案即可。
int query(int p)
{
if((!p)|| t[p].mn[0]>X2 || t[p].mx[0]<X1 || t[p].mn[1]>Y2 || t[p].mx[1]<Y1)return 0;//没交集
if(t[p].mn[0]>=X1 && t[p].mx[0]<=X2 && t[p].mn[1]>=Y1 && t[p].mx[1]<=Y2)return t[p].sum;//被完全包含
int ans=0;
if(t[p].d[0]>=X1 && t[p].d[0]<=X2 && t[p].d[1]>=Y1 && t[p].d[1]<=Y2)ans+=t[p].val;//维护的点在矩形中
return query(t[p].ls)+query(t[p].rs)+ans;//递归左右子树
}
完整代码 (不要复制哦)
#include<bits/stdc++.h>
using namespace std;
const int MAX=2e5+5;
int K=0,last_ans=0,N=0,cur=0,tot=0;
int X1,X2,Y1,Y2;
double A=0.75;
int g[MAX];
template<typename T>
void read(T &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=x*10+(ch^48);
ch=getchar();
}
x*=f;
}
struct node
{
int d[2];
int mx[2],mn[2];
int sum,val,sz;
int ls,rs;
}t[MAX];
int newnode()
{
return ++cur;
}
bool cmp(int a,int b)
{
return t[a].d[K]<t[b].d[K];
}
void update(int x)
{
int ls=t[x].ls;
int rs=t[x].rs;
t[x].sz=t[ls].sz+t[rs].sz+1;
for(int i=0;i<=1;i++)
{
t[x].mx[i]=t[x].mn[i]=t[x].d[i];
if(ls)
{
t[x].mx[i]=max(t[x].mx[i],t[ls].mx[i]);
t[x].mn[i]=min(t[x].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[x].mx[i]=max(t[x].mx[i],t[rs].mx[i]);
t[x].mn[i]=min(t[x].mn[i],t[rs].mn[i]);
}
}
t[x].sum=t[ls].sum+t[rs].sum+t[x].val;
}
void pia(int p)
{
if(!p)return;
pia(t[p].ls);
g[++tot]=p;
pia(t[p].rs);
}
int rebuild(int l,int r,int k)
{
if(l>r)return 0;
int mid=(l+r)>>1;
K=k;
nth_element(g+l,g+mid,g+r+1,cmp);
t[g[mid]].ls=rebuild(l,mid-1,k^1);
t[g[mid]].rs=rebuild(mid+1,r,k^1);
update(g[mid]);
return g[mid];
}
void check(int &p,int k)
{
int ls=t[p].ls;
int rs=t[p].rs;
if(A*t[p].sz<max(t[ls].sz,t[rs].sz))
{
tot=0;
pia(p);
p=rebuild(1,tot,k);
}
}
void insert(int &p,node temp,int k)
{
if(!p)
{
p=newnode();
t[p]=temp;
t[p].rs=0,t[p].ls=0;
update(p);
return;
}
if(temp.d[k]<=t[p].d[k])insert(t[p].ls,temp,k^1);
else insert(t[p].rs,temp,k^1);
update(p);
check(p,k);
}
int query(int p)
{
if((!p)|| t[p].mn[0]>X2 || t[p].mx[0]<X1 || t[p].mn[1]>Y2 || t[p].mx[1]<Y1)return 0;
if(t[p].mn[0]>=X1 && t[p].mx[0]<=X2 && t[p].mn[1]>=Y1 && t[p].mx[1]<=Y2)return t[p].sum;
int ans=0;
if(t[p].d[0]>=X1 && t[p].d[0]<=X2 && t[p].d[1]>=Y1 && t[p].d[1]<=Y2)ans+=t[p].val;
return query(t[p].ls)+query(t[p].rs)+ans;
}
int main()
{
int n,root=0;
read(n);
while(1)
{
int opt;
read(opt);
if(opt==1)
{
node temp;
read(temp.d[0]);
read(temp.d[1]);
read(temp.val);
temp.d[0]^=last_ans,temp.d[1]^=last_ans,temp.val^=last_ans;
insert(root,temp,0);
}
if(opt==2)
{
read(X1),read(Y1),read(X2),read(Y2);
X1^=last_ans,X2^=last_ans,Y1^=last_ans,Y2^=last_ans;
if(X1>X2)swap(X1,X2);
if(Y1>Y2)swap(Y1,Y2);
last_ans=query(root);
cout<<last_ans<<endl;
}
if(opt==3)break;
}
return 0;
}
查询
最经典的一个询问就是最近点查询,例如已经给了你
关于这个距离,欧几里得距离 和 曼哈顿距离 的查询方法都是一样的,只是计算上有所不同,具体的可以例题上解释。
我们回归正题,这个最近点该怎么查询呢?很显然,直接遍历
剪枝
假如我们当前搜到了k-D Tree上的点
接下来递归左子树和右子树,为了降低我们的时间复杂度,我们可以剪枝,具体是这样的。
- 分别计算出点
Q 到点p 的左子树和右子树所管辖矩形的距离,记为disls 和disrs 。 - 比较
disls 和disrs ,如果disls\lt disrs ,则先递归左子树,否则先递归右子树。
void query(int p)
{
ans=min(ans,dis(p));
int disls=INF;
int disrs=INF;
if(t[p].ls)disls=dismatrix(t[p].ls);//计算查询点到左子树管辖矩形的距离
if(t[p].rs)disrs=dismatrix(t[p].rs);//计算查询点到右子树管辖矩形的距离
if(disls<disrs)
{
if(disls<ans)query(t[p].ls);//先递归左子树
if(disrs<ans)query(t[p].rs);
}
else
{
if(disrs<ans)query(t[p].rs);//先递归右子树
if(disls<ans)query(t[p].ls);
}
}
可以发现,如果我们先递归左子树,更新了答案以后,如果右子树管辖矩形到查询点的距离大于已更新的答案,就直接跳过了,这样我们就完美地做到了剪枝。
那么,来看一道例题吧!
洛谷 P4169 [Violet] 天使玩偶/SJY摆棋子
在二维平面上,给出
1 x y 添加一个点
2 x y 查询所有点中,距离
本题中,距离指曼哈顿距离,即
本题有插入操作,不能静态建树,需要考虑插入新点后是否重构。
现在考虑如何查询,其实就是上面那个过程,主要难点在如何求出某个点到矩形的曼哈顿距离。
看这张美丽的图,假设
int dismatrix(int p)
{
//(X,Y)为查询点
int res=0;
res+=max(0,X-t[p].mx[0])+max(0,t[p].mn[0]-X);
res+=max(0,Y-t[p].mx[1])+max(0,t[p].mn[1]-Y);
return res;
}
对于更高维的,我们给出这样一个公式
由于没有插入操作,静态建树即可。
假设我们现在只查询最小距离,直接枚举
最大距离其实也很容易解决,和查询最小距离反着来即可,注意这时查询点到矩形的距离要取更远的那一个,否则会漏掉答案。
非常好,那就简单了,我们直接枚举
代码实现起来比较麻烦,细节较多。
放一下代码 (不要复制哦)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=1e5+5;
const int INF=1e18;
template<typename T>
void read(T &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=x*10+(ch^48);
ch=getchar();
}
x*=f;
}
int K=0,maxdis,mindis,X,Y;
int ABS(int x)
{
return (x<0?(-x):x);
}
struct node
{
int d[2],mx[2],mn[2];
int ls,rs;
bool del=0;
}t[MAX],ori[MAX];
bool cmp(node a,node b)
{
return a.d[K]<b.d[K];
}
void update(int p)
{
int ls=t[p].ls;
int rs=t[p].rs;
for(int i=0;i<=1;i++)
{
t[p].mx[i]=t[p].mn[i]=t[p].d[i];
if(ls)
{
t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
}
}
}
int build(int l,int r,int k)
{
if(l>r)return 0;
K=k;
int mid=(l+r)>>1;
nth_element(ori+l,ori+mid,ori+r+1,cmp);
t[mid]=ori[mid];
t[mid].ls=build(l,mid-1,k^1);
t[mid].rs=build(mid+1,r,k^1);
update(mid);
return mid;
}
int dis(int p)
{
return ABS(t[p].d[0]-X)+ABS(t[p].d[1]-Y);
}
int dismatrixmin(int p)
{
int res=0;
res+=max((long long)0,X-t[p].mx[0])+max((long long)0,t[p].mn[0]-X);
res+=max((long long)0,Y-t[p].mx[1])+max((long long)0,t[p].mn[1]-Y);
return res;
}
int dismatrixmax(int p)
{
int res=0;
res+=max((long long)0,t[p].mx[0]-X)+max((long long)0,X-t[p].mn[0]);
res+=max((long long)0,t[p].mx[1]-Y)+max((long long)0,Y-t[p].mn[1]);
return res;
}
void querymin(int p)
{
if(!t[p].del)mindis=min(mindis,dis(p));
int dl=INF,dr=INF;
if(t[p].ls)dl=dismatrixmin(t[p].ls);
if(t[p].rs)dr=dismatrixmin(t[p].rs);
if(dl<dr)
{
if(dl<mindis)querymin(t[p].ls);
if(dr<mindis)querymin(t[p].rs);
}
else
{
if(dr<mindis)querymin(t[p].rs);
if(dl<mindis)querymin(t[p].ls);
}
}
void querymax(int p)
{
if(!t[p].del)maxdis=max(maxdis,dis(p));
int dl=-INF,dr=-INF;
if(t[p].ls)dl=dismatrixmax(t[p].ls);
if(t[p].rs)dr=dismatrixmax(t[p].rs);
if(dl>dr)
{
if(dl>maxdis)querymax(t[p].ls);
if(dr>maxdis)querymax(t[p].rs);
}
else
{
if(dr>maxdis)querymax(t[p].rs);
if(dl>maxdis)querymax(t[p].ls);
}
}
signed main()
{
int n,root=0,ans=INF;
read(n);
for(int i=1;i<=n;i++)read(ori[i].d[0]),read(ori[i].d[1]);
root=build(1,n,0);
for(int i=1;i<=n;i++)
{
t[i].del=1;
X=t[i].d[0],Y=t[i].d[1];
maxdis=-INF,mindis=INF;
querymin(root),querymax(root);
/*cout<<t[i].mn[0]<<" "<<t[i].mn[1]<<" "<<t[i].mx[0]<<" "<<t[i].mx[1]<<endl;
cout<<"ls"<<t[i].ls<<endl;
cout<<"rs"<<t[i].rs<<endl;*/
ans=min(ans,maxdis-mindis);
t[i].del=0;
}
cout<<ans<<endl;
return 0;
}
k 远距离查询
洛谷 P2093 [国家集训队] JZPFAR
二维平面上有
本题中,距离指欧几里得距离,即
是个麻烦题。首先,为了方便,我们在计算距离时可以不用开方,不影响答案。
由于没有插入操作,静态建树即可。
查询最大距离就不用说了,之前提过了,现在考虑怎么统计第
我们可以开一个 小根堆,先插入
具体过程如下:
struct ans
{
int dis,id;
};
bool operator <(ans a,ans b)
{
if(a.dis==b.dis)return a.id<b.id;
return a.dis>b.dis;
}
priority_queue<ans>q;
void query(int p)
{
if(!p)return;
int dis=getdis(t[p].d[0],t[p].d[1],X,Y);//获取当前点到查询点的距离
if(dis>q.top().dis || (dis==q.top().dis && t[p].id<q.top().id))//距离更远
{
q.pop();
q.push({dis,t[p].id});//把更远的丢进堆
}
//剪枝优化
int dl=-INF,dr=-INF;
if(t[p].ls)dl=dismatrix(t[p].ls);
if(t[p].rs)dr=dismatrix(t[p].rs);
if(dl>dr)
{
if(dl>=q.top().dis)query(t[p].ls);
if(dr>=q.top().dis)query(t[p].rs);
}
else
{
if(dr>=q.top().dis)query(t[p].rs);
if(dl>=q.top().dis)query(t[p].ls);
}
}
完整代码 (不要复制哦)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=1e6+5;
const int INF=1e18;
template<typename T>
void read(T &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=x*10+(ch^48);
ch=getchar();
}
x*=f;
}
struct node
{
int d[2],mx[2],mn[2];
int ls,rs;
int id;
}t[MAX],ori[MAX];
struct ans
{
int dis,id;
};
bool operator <(ans a,ans b)
{
if(a.dis==b.dis)return a.id<b.id;
return a.dis>b.dis;
}
priority_queue<ans>q;
int K=0,X,Y;
bool cmp(node a,node b)
{
return a.d[K]<b.d[K];
}
void update(int p)
{
int ls=t[p].ls;
int rs=t[p].rs;
for(int i=0;i<=1;i++)
{
t[p].mx[i]=t[p].mn[i]=t[p].d[i];
if(ls)
{
t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
}
}
}
int build(int l,int r,int k)
{
if(l>r)return 0;
K=k;
int mid=(l+r)>>1;
nth_element(ori+l,ori+mid,ori+r+1,cmp);
t[mid]=ori[mid];
t[mid].ls=build(l,mid-1,k^1);
t[mid].rs=build(mid+1,r,k^1);
update(mid);
return mid;
}
int getdis(int x1,int y1,int x2,int y2)
{
int k1=(x1-x2)*(x1-x2);
int k2=(y1-y2)*(y1-y2);
return k1+k2;
}
int dismatrix(int p)
{
int dis=-INF;
dis=max(dis,getdis(t[p].mn[0],t[p].mn[1],X,Y));
dis=max(dis,getdis(t[p].mx[0],t[p].mn[1],X,Y));
dis=max(dis,getdis(t[p].mx[0],t[p].mx[1],X,Y));
dis=max(dis,getdis(t[p].mn[0],t[p].mx[1],X,Y));
return dis;
}
void query(int p)
{
if(!p)return;
int dis=getdis(t[p].d[0],t[p].d[1],X,Y);
if(dis>q.top().dis || (dis==q.top().dis && t[p].id<q.top().id))
{
q.pop();
q.push({dis,t[p].id});
}
int dl=-INF,dr=-INF;
if(t[p].ls)dl=dismatrix(t[p].ls);
if(t[p].rs)dr=dismatrix(t[p].rs);
if(dl>dr)
{
if(dl>=q.top().dis)query(t[p].ls);
if(dr>=q.top().dis)query(t[p].rs);
}
else
{
if(dr>=q.top().dis)query(t[p].rs);
if(dl>=q.top().dis)query(t[p].ls);
}
}
signed main()
{
int n,m,k;
read(n);
for(int i=1;i<=n;i++){read(ori[i].d[0]),read(ori[i].d[1]);ori[i].id=i;}
int root=build(1,n,0);
read(m);
for(int i=1;i<=m;i++)
{
read(X),read(Y),read(k);
while(!q.empty())q.pop();
for(int i=1;i<=k;i++)q.push({-INF,INF});
query(root);
cout<<q.top().id<<endl;
}
return 0;
}
圆的相交问题
洛谷 P4631 [APIO2018] 选圆圈
在二维平面上,有
- 找到这些圆中半径最大的。如果有多个半径最大的圆,选择编号最小的。记为
c_i 。 - 删除
c_i 及与其有交集的所有圆。两个圆有交集当且仅当平面上存在一个点,这个点同时在这两个圆的圆周上或圆内。 - 重复上面两个步骤直到所有的圆都被删除。
当
样例:
11
9 9 2
13 2 1
11 8 2
3 3 2
3 12 1
12 14 1
9 8 5
2 8 2
5 2 1
14 4 2
14 14 1
7 2 7 4 5 6 7 7 4 7 6
样例解释:
有圆,不太好处理,看起来无从下手?
我们不妨把每个圆看成矩形(即边长为
那么怎么储存每个圆的信息呢?还是k-D Tree,这时我们每个节点维护的信息比较多。
struct circle
{
int d[2];//圆的坐标
int r,id;//半径和编号
}c[MAX];
struct node
{
circle dat;//该点代表的圆
int mx[2],mn[2];//近似矩形
int ls,rs;//左右儿子
}t[MAX];
矩形很好维护,只是四个顶点的坐标和之前有所不同,用圆的半径和坐标计算即可。
由于没有插入操作,静态建树即可。
接下来是比较关键的内容,如何查询哪些圆与当前圆相交。
首先是判断圆与圆是否有交集,比较好推,也是一个比较常见的结论。
对于两个圆
假设我们当前查询到了k-D Tree上的点
至于怎么判断两个矩形是否相交,条件较多,我们不妨判断两个矩形是否不相交:
- 如果矩形
A 最右端小于矩形B 最左端,则两个矩形不相交。 - 如果矩形
A 最左端大于矩形B 最右端,则两个矩形不相交。 - 如果矩形
A 最上端小于矩形B 最下端,则两个矩形不相交。 - 如果矩形
A 最下端大于矩形B 最上端,则两个矩形不相交。
完整代码 (不要复制哦)
#include<bits/stdc++.h>
using namespace std;
#define int long long
template<typename T>
void read(T &x)
{
x=0;
int f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=x*10+(ch^48);
ch=getchar();
}
x*=f;
}
const int MAX=1e6+5;
struct circle
{
int d[2];
int r,id;
}c[MAX];
struct node
{
circle dat;
int mx[2],mn[2];
int ls,rs;
}t[MAX];
int K=0;
int ans[MAX];
bool cmp(circle a,circle b)
{
return a.d[K]<b.d[K];
}
bool recmp(circle a,circle b)
{
if(a.r==b.r)return a.id<b.id;
return a.r>b.r;
}
void update(int p)
{
int ls=t[p].ls;
int rs=t[p].rs;
for(int i=0;i<=1;i++)
{
t[p].mn[i]=t[p].dat.d[i]-t[p].dat.r;
t[p].mx[i]=t[p].dat.d[i]+t[p].dat.r;
if(ls)
{
t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
}
if(rs)
{
t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
}
}
}
int build(int l,int r,int k)
{
if(l>r)return 0;
K=k;
int mid=(l+r)>>1;
nth_element(c+l,c+mid,c+r+1,cmp);
t[mid].dat=c[mid];
t[mid].ls=build(l,mid-1,k^1);
t[mid].rs=build(mid+1,r,k^1);
update(mid);
return mid;
}
bool intersect(circle a,circle b)
{
int x=a.d[0];
int y=a.d[1];
int r=a.r;
int sec1=(x-b.d[0])*(x-b.d[0]);
int sec2=(y-b.d[1])*(y-b.d[1]);
int sec3=(r+b.r)*(r+b.r);
if(sec1+sec2<=sec3)return 1;
return 0;
}
bool dismeet(int p,circle temp)
{
int L=temp.d[0]-temp.r;
int R=temp.d[0]+temp.r;
int U=temp.d[1]+temp.r;
int D=temp.d[1]-temp.r;
if(R<t[p].mn[0])return 1;
if(U<t[p].mn[1])return 1;
if(L>t[p].mx[0])return 1;
if(D>t[p].mx[1])return 1;
return 0;
}
void query(int p,circle temp)
{
if(dismeet(p,temp))return;
if(intersect(t[p].dat,temp) && !ans[t[p].dat.id])ans[t[p].dat.id]=temp.id;
int ls=t[p].ls;
int rs=t[p].rs;
if(ls)query(ls,temp);
if(rs)query(rs,temp);
}
signed main()
{
int n;
read(n);
for(int i=1;i<=n;i++)
{
read(c[i].d[0]),read(c[i].d[1]),read(c[i].r);
c[i].id=i;
}
int root=build(1,n,0);
sort(c+1,c+n+1,recmp);
for(int i=1;i<=n;i++)if(!ans[c[i].id])query(root,c[i]);
for(int i=1;i<=n;i++)cout<<ans[i]<<" ";
return 0;
}
总结
如果你想牢记这个算法,最好按照我的写法来,记忆起来非常方便,并且思路比较清晰。感谢观看