K-D Tree从入门到精通

· · 个人记录

K-D Tree 从入门到精通

本教程由 Jasmine_Aura 独立编写。

我的洛谷

K-D Tree基础

什么是K-D Tree?

k-D Tree(KDT , k-Dimension Tree) 是一种可以 高效处理 k 维空间信息 的数据结构,常用于 领域查询,最近点对查询 等操作。

k-D Tree具有二叉搜索树的形态,其每个节点都对应 k 维空间内的一个点。

在题目中,一般 k=2

节点信息储存

我们通常用一个结构体储存k-D Tree所有节点的信息

//假设维护的是K维的空间信息
struct node
{
    int d[K];//d[i]表示第i+1维的坐标(因为下标从0开始)
    int mx[K],mn[K];//mx[i]和mn[i]表示该节点所管辖的第i+1维坐标极值,比较关键,后面会讲
    int ls,rs;//左右儿子
    int val;//节点权值,可能会用到
    int sum;//节点子树的权值和,可能会用到
}t[MAX],ori[MAX];//t为k-D Tree所使用的数组

建树

k-D Tree常见的建树方式有两种:交替建树方差建树

交替建树代码实现简单,容易记忆,而方差建树较为复杂,用的较少,故这里只介绍交替建树的方法

交替建树

交替建树有以下几个步骤:

  1. 将当前点集的所有点按第 d 维排序,取出中位数。
  2. 将取出的中位数作为当前点集的根节点,剩余的点集分别作为该点的左子树和右子树。
  3. 重复以上两个步骤,每次 d 要改成 d+1,如果 d+1\gt k ,则 d=1(其中 k 为所有点的维度),直到所有树上的所有点都确定。

有点抽象?不怕!举个例子就懂了!

给出 72 维空间的点,分别为 (1,4),(2,5),(1,5),(2,3),(3,1),(4,1),(3,4)

首先按第 1 维排序,得到的点集是 (1,4),(1,5),(2,5),(2,3),(3,1),(3,4),(4,1)

取出中位数 (2,3),作为当前点集的根节点,剩余的点集分别作为左子树和右子树,如图:

对于剩下的两个点集,我们继续建树,此时我们要对第 2 维排序。

然后得到两个点集,(1,4),(1,5),(2,5)(3,1),(4,1),(3,4)

分别取出中位数 (1,5)(4,1) 作为两个点集的根节点,剩余的点集继续建树,最后得到的树如图所示:

这样,我们就通过 交替建树 的方式得到了一颗层数为 \log n 的 k-D Tree。

这一过程可以用一个平面直角坐标系来表示:

找中位数是个比较棘手的事情,我们当然可以使用 sort() 排序,然后找到位置为 mid 的元素,但是这样的时间复杂度是 O(n\log n)

我们其实还可以使用这样一个函数 nth_element() ,将 lr 之间的数按照比较规则 cmp 排序后,位置为 mid 的元素就是中位数,像这样 nth_element(ori+l,ori+mid,ori+r+1,cmp)。这样的时间复杂度是 O(\log n)

k-D Tree建树整体的复杂度是 O(n\log n)

代码实现:

int K;//排序时的维度
bool cmp(node a,node b)
{
    return a.d[K]<b.d[K];//按第K维排序
}
int build(int l,int r,int k)//l,r分别是集合的边界,k是排序时的维度
{
    if(l>r)return 0;
    int mid=(l+r)>>1;
    K=k;
    nth_element(ori+l,ori+mid,ori+r+1,cmp);//ori为点集
    t[mid]=ori[mid];//t为k-D Tree所使用的数组
    t[mid].ls=build(l,mid-1,k^1);//k^1,切换维度,交替建树
    t[mid].rs=build(mid+1,r,k^1);
    update(mid);//合并子树信息,后面会讲
    return mid;//返回节点编号
}

子树信息合并

每个节点都需要维护 mx[i]mn[i] , 即该节点所管辖的第 i+1 维坐标极值。

对于 2 维的 k-D Tree,(mx[0],mx[1]),(mn[0],mx[1]),(mx[0],mn[1]),(mn[0],mn[1]) 这四个点可以看作该节点所管辖矩形的四个顶点。

合并起来比较类似于线段树,看代码:

void update(int p)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    t[p].sum=t[ls].sum+t[rs].sum+t[p].val;//维护节点权值和,可能会用到
    for(int i=0;i<K;i++)//K为所有点的维度,例如平面上的点的K就是2
    {
        t[p].mx[i]=t[p].mn[i]=t[p].d[i];//初始化
        if(ls)
        {
            t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
        }
        //这里比较容易写错,需要细心检查
    }
}

mx[i]mn[i] 非常关键,合理使用能极大优化查询时的复杂度。

下面放一个只需要建树和简单的查询的小k-D Tree模板题

洛谷 P4475 巧克力王国

给出 n 个二元组 (x,y) ,每个二元组有一个权值 h,现有 m 次询问,每次给出 a,b,c,询问所有满足 ax+by\lt c 的二元组的权值和。

请思考后再看题解哦!!!

题解:

每个二元组可以抽象成平面上的点,使用k-D Tree来维护,建树就不用说了,把权值和维护上就行,现在主要来说查询。

我们已经维护了 mx[i]mn[i] ,如果当前节点管辖的矩形的四个顶点都满足 ax+by\lt c,说明该节点子树中的所有节点都满足 ax+by\lt c,直接返回权值和即可;如果都不满足,说明子树中的所有节点都不满足,直接返回 0。否则就像线段树一样,递归左右子树统计答案。

int query(int x)
{
    int tot=0;
    tot+=check(t[x].mx[0],t[x].mx[1]);
    tot+=check(t[x].mn[0],t[x].mn[1]);
    tot+=check(t[x].mx[0],t[x].mn[1]);
    tot+=check(t[x].mn[0],t[x].mx[1]);
    if(tot==4)return t[x].sum;//当前节点管辖的矩形满足 ax+by<c
    if(tot==0)return 0;//都不满足
    int ans=0;//ans为答案
    if(check(t[x].d[0],t[x].d[1]))ans+=t[x].val;//该节点满足 ax+by<c,先加上这个节点的权值
    if(t[x].ls)ans+=query(t[x].ls);//递归左子树
    if(t[x].rs)ans+=query(t[x].rs);//递归右子树
    return ans;
}

完整代码(不要复制哦)

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=5e4+4;
int a,b,c,K;
template<typename T>
void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch^48);
        ch=getchar();
    }
    x*=f;
}
struct node
{
    int d[5];
    int mx[5],mn[5];
    int val,sum,ls,rs;
}ori[MAX],t[MAX];
bool cmp(node a,node b)
{
    return a.d[K]<b.d[K];
}
bool check(int x,int y)
{
    return a*x+b*y<c;
}
void update(int x)
{
    int ls=t[x].ls;
    int rs=t[x].rs;
    for(int i=0;i<=1;i++)
    {
        t[x].mx[i]=t[x].mn[i]=t[x].d[i];
        if(ls)
        {
            t[x].mx[i]=max(t[x].mx[i],t[ls].mx[i]);
            t[x].mn[i]=min(t[x].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[x].mx[i]=max(t[x].mx[i],t[rs].mx[i]);
            t[x].mn[i]=min(t[x].mn[i],t[rs].mn[i]);
        }
    }
    t[x].sum=t[ls].sum+t[rs].sum+t[x].val;
}
int build(int l,int r,int k)
{
    if(l>r)return 0;
    int mid=(l+r)>>1;
    K=k;
    nth_element(ori+l,ori+mid,ori+r+1,cmp);
    t[mid]=ori[mid];
    if(l<mid)t[mid].ls=build(l,mid-1,k^1);
    if(r>mid)t[mid].rs=build(mid+1,r,k^1);
    update(mid);
    return mid;
}
int query(int x)
{
    int tot=0;
    tot+=check(t[x].mx[0],t[x].mx[1]);
    tot+=check(t[x].mn[0],t[x].mn[1]);
    tot+=check(t[x].mx[0],t[x].mn[1]);
    tot+=check(t[x].mn[0],t[x].mx[1]);
    if(tot==4)return t[x].sum;
    if(tot==0)return 0;
    int ans=0;
    if(check(t[x].d[0],t[x].d[1]))ans+=t[x].val;
    if(t[x].ls)ans+=query(t[x].ls);
    if(t[x].rs)ans+=query(t[x].rs);
    return ans;
}
signed main()
{
    int n,m;
    read(n),read(m);
    for(int i=1;i<=n;i++)
    {
        read(ori[i].d[0]),read(ori[i].d[1]),read(ori[i].val);
    }
    int root=build(1,n,0);
    for(int i=1;i<=m;i++)
    {
        read(a),read(b),read(c);
        cout<<query(root)<<endl;
    }
    return 0;
}

插入重构

插入

上面只是简简单单的静态的建树,如果我要插入点,你该怎么办?

首先,由于k-D Tree具有二叉搜索树的形态,插入的时候只需要和 当前节点的对应维度坐标 作比较即可,如果小于等于当前节点的第 d 维,则走向左子树,否则走向右子树。

维度的切换和交替建树一样,1K 轮流来即可。

例如,我要将 (5,0) 插入这棵树中。

(2,3) 开始,比较第 1 维,5\gt 2,走向右子树。

到达 (4,1),比较第 2 维,0\lt 1,走向左子树。

到达 (3,1),比较第 1 维,5\gt 3,走向右子树。

这样一来,插入操作就成功了,吗?

如果我疯狂插入几次极端的点,k-D Tree就有可能变成这样,层数就不再是 \log n 了,这时候该怎么办呢?重构!

void insert(int &p,node temp,int k)//k为比较维度
{
    if(!p)//k-D Tree中没这个点,直接插入
    {
        p=newnode();//给个新的编号
        t[p]=temp;
        t[p].rs=0,t[p].ls=0;//这步非常重要,不然左右儿子信息会混乱
        update(p);//上传信息
        return;
    }
    if(temp.d[k]<=t[p].d[k])insert(t[p].ls,temp,k^1);//插进左子树
    else insert(t[p].rs,temp,k^1);//插进右子树
    update(p);//上传信息
    check(p,k);//判断是否失衡
    //注意要先update()再check()
}

怎么判断是否需要重构?类似 替罪羊树(没学过也没关系,我也没学过),我们引入一个 平衡因子 \alpha,如果该节点的左子树或者右子树的子树大小超过整颗子树大小的 \alpha 倍,即 sz[p]*\alpha<max(sz[ls],sz[rs]),我们就需要进行重构,通常 0.6\leq \alpha \leq0.9,一般取中间值 0.75,可以根据喜好自己调。

update() 函数中,我们需要额外维护子树的大小。

暴力重构

类似静态建树,我们把失衡的子树拍成一个序列,重新对当前子树建树。

bool cmp(int a,int b)//比较的函数有所改变
{
    return t[a].d[K]<t[b].d[K];
}
void pia(int p)
{
    if(!p)return;
    pia(t[p].ls);
    g[++tot]=p;//把当前节点添到序列中
    pia(t[p].rs);
}
void update(int p)
{
    ...
    t[p].sz=t[ls].sz+t[rs].sz+1;
}
int rebuild(int l,int r,int k)//类似静态建树 ,参数有所改变
{
    if(l>r)return 0;
    int mid=(l+r)>>1;
    K=k;
    nth_element(g+l,g+mid,g+r+1,cmp);//注意这里传的是拍出的序列
    t[g[mid]].ls=rebuild(l,mid-1,k^1);
    t[g[mid]].rs=rebuild(mid+1,r,k^1);
    update(g[mid]);
    return g[mid];
}
void check(int &p,int k)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    if(A*t[p].sz<max(t[ls].sz,t[rs].sz))//A为平衡因子
    {
        tot=0;//清空序列
        pia(p);//把以p为根的子树拍成序列,方便重构
        p=rebuild(1,tot,k);//暴力重构,并更新p点编号
    }
}

这么难的东西都学会了,那么再来做一道例题罢!

洛谷 P4148 简单题

给一个 N\times N 的棋盘,每个格子里有一个整数,初始时都为 0,现在需要维护两种操作:

1 x y A 将格子 xy 里的数字加上 A

2 x1 y1 x2 y2 输出 x1,y1,x2,y2 这个矩形内的数字和。

N\leq5\times10^5

强制在线,内存限制 20MB

非常酷的题,强制在线卡掉了 CDQ分治,内存限制 20MB 卡掉了树套树,还是老实用k-D Tree吧!

观察到 N 非常大,我们不可能把棋盘中每个点加到k-d Tree中。由于初始时每个点的值都是 0,我们考虑操作 1,给 (x,y) 上的点加上 A 可以看作在 (x,y) 插入一个权值为 A 的点。如果我给同一个位置加两次怎么办呢?不怕,我们可以把它看作两个不同的点,因为我们查询的是矩形内的数字和,把同一个位置拆开并不影响答案。

插入操作解决了,现在考虑如何查询。

如果当前节点维护的区域完全被询问的矩形包含,那说明它子树中的所有节点都被包含,直接返回权值和。

如果当前节点维护的区域和询问的矩形完全没有交集,那说明它子树中的所有节点都不在矩形中,直接返回 0

如果当前节点维护的点在矩形内,给答案加上这个点的权值。

剩下的就可以直接递归该节点的左右子树,累加答案即可。

int query(int p)
{
    if((!p)|| t[p].mn[0]>X2 || t[p].mx[0]<X1 || t[p].mn[1]>Y2 || t[p].mx[1]<Y1)return 0;//没交集
    if(t[p].mn[0]>=X1 && t[p].mx[0]<=X2 && t[p].mn[1]>=Y1 && t[p].mx[1]<=Y2)return t[p].sum;//被完全包含
    int ans=0;
    if(t[p].d[0]>=X1 && t[p].d[0]<=X2 && t[p].d[1]>=Y1 && t[p].d[1]<=Y2)ans+=t[p].val;//维护的点在矩形中
    return query(t[p].ls)+query(t[p].rs)+ans;//递归左右子树
}

完整代码 (不要复制哦)

#include<bits/stdc++.h>
using namespace std;
const int MAX=2e5+5;
int K=0,last_ans=0,N=0,cur=0,tot=0;
int X1,X2,Y1,Y2;
double A=0.75;
int g[MAX];
template<typename T>
void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch^48);
        ch=getchar();
    }
    x*=f;
}
struct node
{
    int d[2];
    int mx[2],mn[2];
    int sum,val,sz;
    int ls,rs;
}t[MAX];
int newnode()
{
    return ++cur;
}
bool cmp(int a,int b)
{
    return t[a].d[K]<t[b].d[K];
}   
void update(int x)
{
    int ls=t[x].ls;
    int rs=t[x].rs;
    t[x].sz=t[ls].sz+t[rs].sz+1;
    for(int i=0;i<=1;i++)
    {
        t[x].mx[i]=t[x].mn[i]=t[x].d[i];
        if(ls)
        {
            t[x].mx[i]=max(t[x].mx[i],t[ls].mx[i]);
            t[x].mn[i]=min(t[x].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[x].mx[i]=max(t[x].mx[i],t[rs].mx[i]);
            t[x].mn[i]=min(t[x].mn[i],t[rs].mn[i]);
        }
    }
    t[x].sum=t[ls].sum+t[rs].sum+t[x].val;
}
void pia(int p)
{
    if(!p)return;
    pia(t[p].ls);
    g[++tot]=p;
    pia(t[p].rs);
}
int rebuild(int l,int r,int k)
{
    if(l>r)return 0;
    int mid=(l+r)>>1;
    K=k;
    nth_element(g+l,g+mid,g+r+1,cmp);
    t[g[mid]].ls=rebuild(l,mid-1,k^1);
    t[g[mid]].rs=rebuild(mid+1,r,k^1);
    update(g[mid]);
    return g[mid];
}
void check(int &p,int k)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    if(A*t[p].sz<max(t[ls].sz,t[rs].sz))
    {
        tot=0;
        pia(p);
        p=rebuild(1,tot,k);
    }
}
void insert(int &p,node temp,int k)
{
    if(!p)
    {
        p=newnode();
        t[p]=temp;
        t[p].rs=0,t[p].ls=0;
        update(p);
        return;
    }
    if(temp.d[k]<=t[p].d[k])insert(t[p].ls,temp,k^1);
    else insert(t[p].rs,temp,k^1);
    update(p);
    check(p,k);
}
int query(int p)
{
    if((!p)|| t[p].mn[0]>X2 || t[p].mx[0]<X1 || t[p].mn[1]>Y2 || t[p].mx[1]<Y1)return 0;
    if(t[p].mn[0]>=X1 && t[p].mx[0]<=X2 && t[p].mn[1]>=Y1 && t[p].mx[1]<=Y2)return t[p].sum;
    int ans=0;
    if(t[p].d[0]>=X1 && t[p].d[0]<=X2 && t[p].d[1]>=Y1 && t[p].d[1]<=Y2)ans+=t[p].val;
    return query(t[p].ls)+query(t[p].rs)+ans;
}
int main()
{
    int n,root=0;
    read(n);
    while(1)
    {
        int opt;
        read(opt);
        if(opt==1)
        {
            node temp;
            read(temp.d[0]);
            read(temp.d[1]);
            read(temp.val);
            temp.d[0]^=last_ans,temp.d[1]^=last_ans,temp.val^=last_ans;
            insert(root,temp,0);
        }
        if(opt==2)
        {
            read(X1),read(Y1),read(X2),read(Y2);
            X1^=last_ans,X2^=last_ans,Y1^=last_ans,Y2^=last_ans;
            if(X1>X2)swap(X1,X2);
            if(Y1>Y2)swap(Y1,Y2);
            last_ans=query(root);
            cout<<last_ans<<endl;
        }
        if(opt==3)break;
    }
    return 0;
}

查询

最经典的一个询问就是最近点查询,例如已经给了你 n 个二维平面上的点,现在给你一个点 Q,问你这 n 个点中距离 Q 最近的是哪个。

关于这个距离,欧几里得距离曼哈顿距离 的查询方法都是一样的,只是计算上有所不同,具体的可以例题上解释。

我们回归正题,这个最近点该怎么查询呢?很显然,直接遍历 n 个点是很不现实的,我们考虑如何从k-D Tree上搜到答案。

剪枝

假如我们当前搜到了k-D Tree上的点 p,先更新一下最短距离 ans=min(ans,dis(p,Q))

接下来递归左子树和右子树,为了降低我们的时间复杂度,我们可以剪枝,具体是这样的。

  1. 分别计算出点 Q 到点 p 的左子树和右子树所管辖矩形的距离,记为 dislsdisrs
  2. 比较 dislsdisrs,如果 disls\lt disrs,则先递归左子树,否则先递归右子树。
void query(int p)
{
    ans=min(ans,dis(p));
    int disls=INF;
    int disrs=INF;
    if(t[p].ls)disls=dismatrix(t[p].ls);//计算查询点到左子树管辖矩形的距离
    if(t[p].rs)disrs=dismatrix(t[p].rs);//计算查询点到右子树管辖矩形的距离
    if(disls<disrs)
    {
        if(disls<ans)query(t[p].ls);//先递归左子树
        if(disrs<ans)query(t[p].rs);
    }
    else
    {
        if(disrs<ans)query(t[p].rs);//先递归右子树
        if(disls<ans)query(t[p].ls);
    }
}

可以发现,如果我们先递归左子树,更新了答案以后,如果右子树管辖矩形到查询点的距离大于已更新的答案,就直接跳过了,这样我们就完美地做到了剪枝。

那么,来看一道例题吧!

洛谷 P4169 [Violet] 天使玩偶/SJY摆棋子

在二维平面上,给出 n 个点 (x,y),以及 m 个操作。

1 x y 添加一个点 (x,y)

2 x y 查询所有点中,距离 (x,y) 最近的一个点。

本题中,距离指曼哈顿距离,即 dist(A,B)=|A_x-B_x|+|A_y-B_y|

本题有插入操作,不能静态建树,需要考虑插入新点后是否重构。

现在考虑如何查询,其实就是上面那个过程,主要难点在如何求出某个点到矩形的曼哈顿距离。

看这张美丽的图,假设 A,B 是我们给出的查询点,这个矩形是某个子树所管辖的矩形,看图稍微推一下,我们可以这样计算它们的距离:

int dismatrix(int p)
{
    //(X,Y)为查询点
    int res=0;
    res+=max(0,X-t[p].mx[0])+max(0,t[p].mn[0]-X);
    res+=max(0,Y-t[p].mx[1])+max(0,t[p].mn[1]-Y);
    return res;
}

对于更高维的,我们给出这样一个公式

dist(T,P)=\sum_{i=0}^{K-1}max(0,T_i-P_{mx_i})+max(0,P_{mn_i}-T_i) 完整代码 **(不要复制哦)** ```cpp #include<bits/stdc++.h> using namespace std; const int MAX=1e6+5; template<typename T> void read(T &x) { x=0; int f=1; char ch=getchar(); while(!isdigit(ch)) { if(ch=='-')f=-1; ch=getchar(); } while(isdigit(ch)) { x=x*10+(ch^48); ch=getchar(); } x*=f; } int K=0,cur=0,cnt=0,X,Y,ans; double A=0.75; struct node { int d[2]; int mn[2],mx[2]; int sz,ls,rs; }t[MAX]; int g[MAX]; bool cmp(int a,int b) { return t[a].d[K]<t[b].d[K]; } void update(int p) { int ls=t[p].ls; int rs=t[p].rs; for(int i=0;i<=1;i++) { t[p].mx[i]=t[p].mn[i]=t[p].d[i]; if(ls) { t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]); t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]); } if(rs) { t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]); t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]); } } t[p].sz=t[ls].sz+t[rs].sz+1; } void pia(int p) { if(!p)return; pia(t[p].ls); g[++cnt]=p; pia(t[p].rs); } int rebuild(int l,int r,int k) { if(l>r)return 0; K=k; int mid=(l+r)>>1; nth_element(g+l,g+mid,g+r+1,cmp); t[g[mid]].ls=rebuild(l,mid-1,k^1); t[g[mid]].rs=rebuild(mid+1,r,k^1); update(g[mid]); return g[mid]; } void check(int &p,int k) { int ls=t[p].ls; int rs=t[p].rs; if(t[p].sz*A<max(t[ls].sz,t[rs].sz)) { cnt=0; pia(p); p=rebuild(1,cnt,k); } } void insert(int &p,node temp,int k) { if(!p) { p=++cur; t[p]=temp; t[p].ls=t[p].rs=0; update(p); return; } if(temp.d[k]<=t[p].d[k])insert(t[p].ls,temp,k^1); else insert(t[p].rs,temp,k^1); update(p); check(p,k); } int dismatrix(int p) { int res=0; res+=max(0,X-t[p].mx[0])+max(0,t[p].mn[0]-X); res+=max(0,Y-t[p].mx[1])+max(0,t[p].mn[1]-Y); return res; } int dis(int p) { return abs(t[p].d[0]-X)+abs(t[p].d[1]-Y); } void query(int p) { ans=min(ans,dis(p)); int disls=1e9; int disrs=1e9; if(t[p].ls)disls=dismatrix(t[p].ls); if(t[p].rs)disrs=dismatrix(t[p].rs); if(disls<disrs) { if(disls<ans)query(t[p].ls); if(disrs<ans)query(t[p].rs); } else { if(disrs<ans)query(t[p].rs); if(disls<ans)query(t[p].ls); } } int main() { //freopen("tmp.in","r",stdin); //freopen("tmp.out","w",stdout); int n,m,root=0; read(n),read(m); int opt; for(int i=1;i<=n;i++) { node temp; read(temp.d[0]),read(temp.d[1]); insert(root,temp,0); } for(int i=1;i<=m;i++) { read(opt); if(opt==1) { node temp; read(temp.d[0]),read(temp.d[1]); insert(root,temp,0); } if(opt==2) { read(X),read(Y); ans=1e9; query(root); printf("%d\n",ans); } } return 0; } ``` ## K-D Tree进阶 了解了k-D Tree的基本用法,接下来看 $3$ 个例题,来提升一下对k-D Tree的理解,同时学习一下其它的高级查询。 ### 最小最大距离最小距离差 #### 洛谷 P2479 [SDOI2010] 捉迷藏 二维平面上,给出 $n$ 个点 $(x,y)$ ,请找出一个点,使得该点到其它点的最大距离和该点到其它点的最小距离之差最小,输出编号。 **本题中,距离指曼哈顿距离**,即 $dist(A,B)=|A_x-B_x|+|A_y-B_y|

由于没有插入操作,静态建树即可。

假设我们现在只查询最小距离,直接枚举 n 个点,依次查询最小距离即可,注意要给枚举的这个点打个标记,在查询过程中不要查询这个点。

最大距离其实也很容易解决,和查询最小距离反着来即可,注意这时查询点到矩形的距离要取更远的那一个,否则会漏掉答案。

非常好,那就简单了,我们直接枚举 n 个点,查询出到每个点的最大距离和最小距离,直接更新答案,即 ans=min(ans,maxdis-mindis)

代码实现起来比较麻烦,细节较多。

放一下代码 (不要复制哦)

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=1e5+5;
const int INF=1e18;
template<typename T>
void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch^48);
        ch=getchar();
    }
    x*=f;
}
int K=0,maxdis,mindis,X,Y;
int ABS(int x)
{
    return (x<0?(-x):x);
}
struct node
{
    int d[2],mx[2],mn[2];
    int ls,rs;
    bool del=0;
}t[MAX],ori[MAX];
bool cmp(node a,node b)
{
    return a.d[K]<b.d[K];
}
void update(int p)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    for(int i=0;i<=1;i++)
    {
        t[p].mx[i]=t[p].mn[i]=t[p].d[i];
        if(ls)
        {
            t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
        }
    }
}
int build(int l,int r,int k)
{
    if(l>r)return 0;
    K=k;
    int mid=(l+r)>>1;
    nth_element(ori+l,ori+mid,ori+r+1,cmp);
    t[mid]=ori[mid];
    t[mid].ls=build(l,mid-1,k^1);
    t[mid].rs=build(mid+1,r,k^1);
    update(mid);
    return mid;
}
int dis(int p)
{
    return ABS(t[p].d[0]-X)+ABS(t[p].d[1]-Y);
}
int dismatrixmin(int p)
{
    int res=0;
    res+=max((long long)0,X-t[p].mx[0])+max((long long)0,t[p].mn[0]-X);
    res+=max((long long)0,Y-t[p].mx[1])+max((long long)0,t[p].mn[1]-Y);
    return res;
}
int dismatrixmax(int p)
{
    int res=0;
    res+=max((long long)0,t[p].mx[0]-X)+max((long long)0,X-t[p].mn[0]);
    res+=max((long long)0,t[p].mx[1]-Y)+max((long long)0,Y-t[p].mn[1]);
    return res;
}
void querymin(int p)
{
    if(!t[p].del)mindis=min(mindis,dis(p));
    int dl=INF,dr=INF;
    if(t[p].ls)dl=dismatrixmin(t[p].ls);
    if(t[p].rs)dr=dismatrixmin(t[p].rs);
    if(dl<dr)
    {
        if(dl<mindis)querymin(t[p].ls);
        if(dr<mindis)querymin(t[p].rs);
    }
    else
    {
        if(dr<mindis)querymin(t[p].rs);
        if(dl<mindis)querymin(t[p].ls);
    }
}
void querymax(int p)
{
    if(!t[p].del)maxdis=max(maxdis,dis(p));
    int dl=-INF,dr=-INF;
    if(t[p].ls)dl=dismatrixmax(t[p].ls);
    if(t[p].rs)dr=dismatrixmax(t[p].rs);
    if(dl>dr)
    {
        if(dl>maxdis)querymax(t[p].ls);
        if(dr>maxdis)querymax(t[p].rs);
    }
    else
    {
        if(dr>maxdis)querymax(t[p].rs);
        if(dl>maxdis)querymax(t[p].ls);
    }
}
signed main()
{
    int n,root=0,ans=INF;
    read(n);
    for(int i=1;i<=n;i++)read(ori[i].d[0]),read(ori[i].d[1]);
    root=build(1,n,0);
    for(int i=1;i<=n;i++)
    {
        t[i].del=1;
        X=t[i].d[0],Y=t[i].d[1];
        maxdis=-INF,mindis=INF;
        querymin(root),querymax(root);
        /*cout<<t[i].mn[0]<<" "<<t[i].mn[1]<<" "<<t[i].mx[0]<<" "<<t[i].mx[1]<<endl;
        cout<<"ls"<<t[i].ls<<endl;
        cout<<"rs"<<t[i].rs<<endl;*/
        ans=min(ans,maxdis-mindis);
        t[i].del=0;
    }
    cout<<ans<<endl;
    return 0;
}

k 远距离查询

洛谷 P2093 [国家集训队] JZPFAR

二维平面上有 n 个点 (x,y),现有 m 次询问,每次询问给出一个点 (px,py) 以及 k,查询这 n 个点中到点 (px,py) 的距离第 k 大的点的编号。如果有多个点到 (px,py) 的距离相同,那么认定编号较小的点距离较大。

本题中,距离指欧几里得距离,即 dist(A,B)=\sqrt {(A_x-B_x)^2+(A_y-B_y)^2}

是个麻烦题。首先,为了方便,我们在计算距离时可以不用开方,不影响答案。

由于没有插入操作,静态建树即可。

查询最大距离就不用说了,之前提过了,现在考虑怎么统计第 k 远。

我们可以开一个 小根堆,先插入 k 个极小值,然后从根节点开始查询,如果 (px,py) 到当前节点的距离大于堆顶存的距离,或者 (px,py) 到当前节点的距离等于堆顶存的距离,但是该点编号小于堆顶存的编号,我们就弹出堆顶,把现在这个点到 (px,py) 的距离和这个点的编号丢进去。查询结束后,堆顶存的编号就是答案。

具体过程如下:

struct ans
{
    int dis,id;
};
bool operator <(ans a,ans b)
{
    if(a.dis==b.dis)return a.id<b.id;
    return a.dis>b.dis;
}
priority_queue<ans>q;
void query(int p)
{
    if(!p)return;
    int dis=getdis(t[p].d[0],t[p].d[1],X,Y);//获取当前点到查询点的距离
    if(dis>q.top().dis || (dis==q.top().dis && t[p].id<q.top().id))//距离更远
    {
        q.pop();
        q.push({dis,t[p].id});//把更远的丢进堆
    }
    //剪枝优化
    int dl=-INF,dr=-INF;
    if(t[p].ls)dl=dismatrix(t[p].ls);
    if(t[p].rs)dr=dismatrix(t[p].rs);
    if(dl>dr)
    {
        if(dl>=q.top().dis)query(t[p].ls);
        if(dr>=q.top().dis)query(t[p].rs);
    }
    else
    {
        if(dr>=q.top().dis)query(t[p].rs);
        if(dl>=q.top().dis)query(t[p].ls);
    }
}

完整代码 (不要复制哦)

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAX=1e6+5;
const int INF=1e18;
template<typename T>
void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch^48);
        ch=getchar();
    }
    x*=f;
}
struct node
{
    int d[2],mx[2],mn[2];
    int ls,rs;
    int id;
}t[MAX],ori[MAX];
struct ans
{
    int dis,id;
};
bool operator <(ans a,ans b)
{
    if(a.dis==b.dis)return a.id<b.id;
    return a.dis>b.dis;
}
priority_queue<ans>q;
int K=0,X,Y;
bool cmp(node a,node b)
{
    return a.d[K]<b.d[K];
}
void update(int p)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    for(int i=0;i<=1;i++)
    {
        t[p].mx[i]=t[p].mn[i]=t[p].d[i];
        if(ls)
        {
            t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
        }
    }
}
int build(int l,int r,int k)
{
    if(l>r)return 0;
    K=k;
    int mid=(l+r)>>1;
    nth_element(ori+l,ori+mid,ori+r+1,cmp);
    t[mid]=ori[mid];
    t[mid].ls=build(l,mid-1,k^1);
    t[mid].rs=build(mid+1,r,k^1);
    update(mid);
    return mid;
}
int getdis(int x1,int y1,int x2,int y2)
{
    int k1=(x1-x2)*(x1-x2);
    int k2=(y1-y2)*(y1-y2);
    return k1+k2;
}
int dismatrix(int p)
{
    int dis=-INF;
    dis=max(dis,getdis(t[p].mn[0],t[p].mn[1],X,Y));
    dis=max(dis,getdis(t[p].mx[0],t[p].mn[1],X,Y));
    dis=max(dis,getdis(t[p].mx[0],t[p].mx[1],X,Y));
    dis=max(dis,getdis(t[p].mn[0],t[p].mx[1],X,Y));
    return dis;
}
void query(int p)
{
    if(!p)return;
    int dis=getdis(t[p].d[0],t[p].d[1],X,Y);
    if(dis>q.top().dis || (dis==q.top().dis && t[p].id<q.top().id))
    {
        q.pop();
        q.push({dis,t[p].id});
    }
    int dl=-INF,dr=-INF;
    if(t[p].ls)dl=dismatrix(t[p].ls);
    if(t[p].rs)dr=dismatrix(t[p].rs);
    if(dl>dr)
    {
        if(dl>=q.top().dis)query(t[p].ls);
        if(dr>=q.top().dis)query(t[p].rs);
    }
    else
    {
        if(dr>=q.top().dis)query(t[p].rs);
        if(dl>=q.top().dis)query(t[p].ls);
    }
}
signed main()
{
    int n,m,k;
    read(n);
    for(int i=1;i<=n;i++){read(ori[i].d[0]),read(ori[i].d[1]);ori[i].id=i;}
    int root=build(1,n,0);
    read(m);
    for(int i=1;i<=m;i++)
    {
        read(X),read(Y),read(k);
        while(!q.empty())q.pop();
        for(int i=1;i<=k;i++)q.push({-INF,INF});
        query(root);
        cout<<q.top().id<<endl;
    }
    return 0;
}

圆的相交问题

洛谷 P4631 [APIO2018] 选圆圈

在二维平面上,有 n 个圆,记为 c_1,c_2,...,c_n。执行以下操作:

  1. 找到这些圆中半径最大的。如果有多个半径最大的圆,选择编号最小的。记为 c_i
  2. 删除 c_i 及与其有交集的所有圆。两个圆有交集当且仅当平面上存在一个点,这个点同时在这两个圆的圆周上或圆内。
  3. 重复上面两个步骤直到所有的圆都被删除。

c_i 被删除时,若循环中第 1 步选择的圆是 c_j,我们说 c_ic_j 删除。对于每个圆,求出它是被哪一个圆删除的。

样例:

11
9 9 2
13 2 1
11 8 2
3 3 2
3 12 1
12 14 1
9 8 5
2 8 2
5 2 1
14 4 2
14 14 1
7 2 7 4 5 6 7 7 4 7 6

样例解释:

有圆,不太好处理,看起来无从下手?

我们不妨把每个圆看成矩形(即边长为 2r 的正方形),那么很容易得出,如果两个矩形没有交集,那么这两个矩形代表的圆一定没有交集。

那么怎么储存每个圆的信息呢?还是k-D Tree,这时我们每个节点维护的信息比较多。

struct circle
{
    int d[2];//圆的坐标
    int r,id;//半径和编号
}c[MAX];
struct node
{
    circle dat;//该点代表的圆
    int mx[2],mn[2];//近似矩形
    int ls,rs;//左右儿子
}t[MAX];

矩形很好维护,只是四个顶点的坐标和之前有所不同,用圆的半径和坐标计算即可。

由于没有插入操作,静态建树即可。

接下来是比较关键的内容,如何查询哪些圆与当前圆相交。

首先是判断圆与圆是否有交集,比较好推,也是一个比较常见的结论。

对于两个圆 A,B,它们圆心的距离为 dist(A,B)=\sqrt {(A_x-B_x)^2+(A_y-B_y)^2},如果它们的半径和大于等于圆心之间的距离,说明它们相交或者相切,即有交集。为了防止出现精度问题,我们可以在式子两边平方一下,得到判定式:

(A_x-B_x)^2+(A_y-B_y)^2\leq (A_r+B_r)^2

假设我们当前查询到了k-D Tree上的点 p,如果 p 管辖的矩形和查询圆的近似矩形没有交集,说明 p 子树内的所有圆一定和查询圆没有交集,直接返回。否则就看 p 维护的圆和查询圆是否有交集,同时 p 点维护的圆没有被删除,那就统计答案,并递归左右子树继续统计答案。

至于怎么判断两个矩形是否相交,条件较多,我们不妨判断两个矩形是否不相交:

完整代码 (不要复制哦)

#include<bits/stdc++.h>
using namespace std;
#define int long long
template<typename T>
void read(T &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=x*10+(ch^48);
        ch=getchar();
    }
    x*=f;
}
const int MAX=1e6+5;
struct circle
{
    int d[2];
    int r,id;
}c[MAX];
struct node
{
    circle dat;
    int mx[2],mn[2];
    int ls,rs;
}t[MAX];
int K=0;
int ans[MAX];
bool cmp(circle a,circle b)
{
    return a.d[K]<b.d[K];
}
bool recmp(circle a,circle b)
{
    if(a.r==b.r)return a.id<b.id;
    return a.r>b.r;
}
void update(int p)
{
    int ls=t[p].ls;
    int rs=t[p].rs;
    for(int i=0;i<=1;i++)
    {
        t[p].mn[i]=t[p].dat.d[i]-t[p].dat.r;
        t[p].mx[i]=t[p].dat.d[i]+t[p].dat.r;
        if(ls)
        {
            t[p].mx[i]=max(t[p].mx[i],t[ls].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[ls].mn[i]);
        }
        if(rs)
        {
            t[p].mx[i]=max(t[p].mx[i],t[rs].mx[i]);
            t[p].mn[i]=min(t[p].mn[i],t[rs].mn[i]);
        }
    }
}
int build(int l,int r,int k)
{
    if(l>r)return 0;
    K=k;
    int mid=(l+r)>>1;
    nth_element(c+l,c+mid,c+r+1,cmp);
    t[mid].dat=c[mid];
    t[mid].ls=build(l,mid-1,k^1);
    t[mid].rs=build(mid+1,r,k^1);
    update(mid);
    return mid;
}
bool intersect(circle a,circle b)
{
    int x=a.d[0];
    int y=a.d[1];
    int r=a.r;
    int sec1=(x-b.d[0])*(x-b.d[0]);
    int sec2=(y-b.d[1])*(y-b.d[1]);
    int sec3=(r+b.r)*(r+b.r);
    if(sec1+sec2<=sec3)return 1;
    return 0;
}
bool dismeet(int p,circle temp)
{
    int L=temp.d[0]-temp.r;
    int R=temp.d[0]+temp.r;
    int U=temp.d[1]+temp.r;
    int D=temp.d[1]-temp.r;
    if(R<t[p].mn[0])return 1;
    if(U<t[p].mn[1])return 1;
    if(L>t[p].mx[0])return 1;
    if(D>t[p].mx[1])return 1;
    return 0;
}
void query(int p,circle temp)
{
    if(dismeet(p,temp))return;
    if(intersect(t[p].dat,temp) && !ans[t[p].dat.id])ans[t[p].dat.id]=temp.id;
    int ls=t[p].ls;
    int rs=t[p].rs;
    if(ls)query(ls,temp);
    if(rs)query(rs,temp);
}
signed main()
{
    int n;
    read(n);
    for(int i=1;i<=n;i++)
    {
        read(c[i].d[0]),read(c[i].d[1]),read(c[i].r);
        c[i].id=i;
    }
    int root=build(1,n,0);
    sort(c+1,c+n+1,recmp);
    for(int i=1;i<=n;i++)if(!ans[c[i].id])query(root,c[i]);
    for(int i=1;i<=n;i++)cout<<ans[i]<<" ";
    return 0;
}

总结

如果你想牢记这个算法,最好按照我的写法来,记忆起来非常方便,并且思路比较清晰。感谢观看