K-D Tree入门
K-D Tree(KDT, k-Dimension Tree) 是一种快速解决多维信息维护的数据结构。
以下假设
引入
在二叉搜索树中,我们对一个集合通过一个分水岭元素分成两个集合,一个集合中的元素都比分水岭元素要小,另一个集合中的元素都比分水岭元素要大。之后依照这样建成一颗二叉树。之后我们就可以在这棵树上进行各种操作。
因为二叉搜索树中每个元素都是一个标量,可以视为维数为
建树
二叉搜索树中的元素只有
- 若当前超正方体中只有一个点,直接返回。
- 选择一个维度,再在这一维度选择一个分界点,分成两个超正方体。一个超正方体中这一维的值小于或等于分界点的这一维(作为这个点的左子树),另一个超正方体中这一维的值大于(或者前面小于,这里大于等于)分界点的这一维。
- 对两个超正方体代表的子树再次进行建树。
这样我们就可以建出一颗 K-D Tree 了。为了方便理解,这里搬 OI Wiki 上的一张图:
图中红色的线表示第一次划分的线,蓝色表示第二次划分的线。上面哪些点根据这样的划分建出来的 K-D Tree 长这个样子(也是从 OI Wiki 盗的):
优化
但是直接用上面的方式建出来的树对查询操作极不友好。我们需要用一些优化让复杂度正确:
- 每次选择划分点的时候选择这一维的中位数,这样树高就是
O(\lg n) 的了。 - 选择划分的维度沿
k 个维度依次选择,这是为了查询的复杂度正确。
在加上以上两个优化之后,我们可以得到一个树高
接着解决如何快速找到中位数的问题。如果直接 sort 一遍,复杂度是 nth_element,可以在 nth_element(s + l, s + mid, s + r + 1, cmp)。
注意:在划分某一维的时候,点中这一维可能会有重复的数,所以我们可能会把两个在这一维相同的数划分到两个子树中。这点在查询的时候要格外注意。(否则树高就不保证是
于是加上这些优化之后 K-D Tree 的建树复杂度就成为了
具体实现看最后的代码。
查询
这里主要讲一下查询一个超正方体中所包含的元素值。我们仿照线段树的思路进行分治查询就行了。具体来说,就是先判断与查询超正方体有没有交点,如果没有就返回;否则判断是否被查询超正方体完全包含,如果完全包含,就直接把这个答案统计进去,否则递归两个子树查询。
时间复杂度我也不会证明,自己去看 OI-Wiki 吧。总之复杂度是
具体实现也看最后的代码。
修改
我们发现 K-D Tree 貌似不支持修改操作,而且又不能像平衡树一样用各种奇奇怪怪的方式解决,只能重构,那么复杂度还不如暴力。
修改主要有两种解决方案:根号重构和二进制分组。具体怎么选择看你自己。反正二进制分组吊打根号重构就是了。
根号重构
我们选择不每一次插入都进行一次重构,而是进行
于是插入的均摊复杂度很明显是
二进制分组
如果点的个数为
于是均摊复杂度就是
查询的时候,就分别在这些树上分别查询,于是单次查询的复杂度就是
代码实现
从复杂度不难发现二进制分组的效率要高得多,所以我写的是二进制分组。
模板题 AC 代码(K-D Tree 模板):
#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
constexpr int K = 2, N = 2e5 + 10, KN = 20;
class Position {
public:
Position() {
for (int i = 0; i < K; i++) dat[i] = 0;
}
Position(initializer_list<int> pos) {
auto it = pos.begin();
for (int i = 0; i < K; i++, it++) dat[i] = *it;
}
Position(const Position& other) {
for (int i = 0; i < K; i++) dat[i] = other.dat[i];
}
const int& operator[](int index) const {
return dat[index];
}
int& operator[](int index) {
return dat[index];
}
void operator=(const Position& other) {
for (int i = 0; i < K; i++) dat[i] = other.dat[i];
}
bool operator==(const Position& other) const {
for (int i = 0; i < K; i++)
if (dat[i] != other.dat[i])
return false;
return true;
}
private:
int dat[K];
};
bool contains(const Position& lpos, const Position& rpos, const Position& pos) {
bool flag = true;
for (int i = 0; i < K; i++) flag = flag && (lpos[i] <= pos[i] && pos[i] <= rpos[i]);
return flag;
}
bool contains(const Position& lpos, const Position& rpos, const Position& qlpos, const Position& qrpos) {
bool flag = true;
for (int i = 0; i < K; i++) flag = flag && (lpos[i] <= qlpos[i] && qrpos[i] <= rpos[i]);
return flag;
}
bool notintersect(const Position& lpos, const Position& rpos, const Position& qlpos, const Position& qrpos) {
bool flag = false;
for (int i = 0; i < K; i++) flag = flag || (qrpos[i] < lpos[i] || qlpos[i] > rpos[i]);
return flag;
}
ostream& operator<<(ostream& output, Position pos) {
output << "[";
for (int i = 0; i < K; i++) output << pos[i] << (i == K - 1 ? "" : ", ");
output << "]";
return output;
}
int trushCan[N], trushTop = 0, curIdx;
void poolPush(int pos) {
trushCan[++trushTop] = pos;
}
int poolGet() {
if (trushTop) return trushCan[trushTop--];
return ++curIdx;
}
struct Node {
int val;
Position pos;
int sum;
Position lpos, rpos;
int lson, rson, fg;
Node() : val(0), pos(), sum(0), lpos(), rpos(), lson(0), rson(0), fg(0) {
}
Node(int val, Position pos) : val(val), pos(pos), sum(val), lpos(), rpos(), lson(0), rson(0), fg(0) {
}
} tree[N];
void pushup(int x) {
tree[x].sum = tree[tree[x].lson].sum + tree[tree[x].rson].sum + tree[x].val;
tree[x].lpos = tree[x].pos, tree[x].rpos = tree[x].pos;
for (int i = 0; i < K; i++) {
if (tree[x].lson) {
tree[x].lpos[i] = min(tree[x].lpos[i], tree[tree[x].lson].lpos[i]);
tree[x].rpos[i] = max(tree[x].rpos[i], tree[tree[x].lson].rpos[i]);
}
if (tree[x].rson) {
tree[x].lpos[i] = min(tree[x].lpos[i], tree[tree[x].rson].lpos[i]);
tree[x].rpos[i] = max(tree[x].rpos[i], tree[tree[x].rson].rpos[i]);
}
}
}
void print(int x, int cs) {
for (int i = 0; i < cs; i++) cout << '-';
cout << "At[" << x << "], val=" << tree[x].val << ", pos=" << tree[x].pos << "; sum=" << tree[x].sum << ", lpos=" << tree[x].lpos << ", rpos=" << tree[x].rpos << "; lson=" << tree[x].lson << ", rson=" << tree[x].rson << ", fg=" << tree[x].fg << endl;
if (tree[x].lson) print(tree[x].lson, cs + 1);
if (tree[x].rson) print(tree[x].rson, cs + 1);
}
int mergeTop;
pair<Position, int> mergeTmp[N];
void expend(int cur) {
if (tree[cur].lson) expend(tree[cur].lson);
poolPush(cur);
mergeTmp[++mergeTop] = make_pair(tree[cur].pos, tree[cur].val);
if (tree[cur].rson) expend(tree[cur].rson);
}
int rebuild(int l, int r, int curfg) {
if (l > r) return 0;
int x = poolGet();
if (l == r) {
tree[x].pos = mergeTmp[l].first;
tree[x].val = mergeTmp[l].second;
tree[x].fg = curfg;
tree[x].lson = tree[x].rson = 0;
pushup(x);
return x;
}
int mid = (l + r + 1) >> 1;
nth_element(mergeTmp + l, mergeTmp + mid, mergeTmp + r + 1, [&](const pair<Position, int>& a, const pair<Position, int>& b) { return a.first[curfg] < b.first[curfg]; });
tree[x].pos = mergeTmp[mid].first;
tree[x].val = mergeTmp[mid].second;
tree[x].fg = curfg;
tree[x].lson = rebuild(l, mid - 1, (curfg + 1) % K);
tree[x].rson = rebuild(mid + 1, r, (curfg + 1) % K);
pushup(x);
return x;
}
int merge(int ra, int rb) {
mergeTop = 0;
expend(ra);
expend(rb);
sort(mergeTmp + 1, mergeTmp + mergeTop + 1, [](pair<Position, int> a, pair<Position, int> b) {
for (int i = 0; i < K; i++)
if (a.first[i] != b.first[i])
return a.first[i] < b.first[i];
return a.second < b.second;
});
int j = 0;
for (int i = 1; i <= mergeTop; i++) {
if (j && mergeTmp[i].first == mergeTmp[j].first) {
mergeTmp[j].second += mergeTmp[i].second;
} else {
mergeTmp[++j] = mergeTmp[i];
}
}
mergeTop = j;
return rebuild(1, mergeTop, 0);
}
int roots[KN];
void push(Position pos, int val) {
mergeTmp[mergeTop = 1] = {pos, val};
int rt = rebuild(1, 1, 0);
int i = 0;
for (; i < KN && roots[i]; i++) {
rt = merge(rt, roots[i]);
roots[i] = 0;
}
roots[i] = rt;
}
int queryone(int cur, Position& lpos, Position& rpos) {
if (!cur) return 0;
if (notintersect(lpos, rpos, tree[cur].lpos, tree[cur].rpos)) return 0;
if (contains(lpos, rpos, tree[cur].lpos, tree[cur].rpos)) return tree[cur].sum;
int ans = 0;
if (contains(lpos, rpos, tree[cur].pos)) ans += tree[cur].val;
ans += queryone(tree[cur].lson, lpos, rpos);
ans += queryone(tree[cur].rson, lpos, rpos);
return ans;
}
int query(Position lpos, Position rpos) {
int ans = 0;
for (int i = 0; i < KN; i++) ans += queryone(roots[i], lpos, rpos);
return ans;
}
int main() {
int n, lastans = 0;
cin >> n;
while (true) {
int opt;
cin >> opt;
if (opt == 1) {
int x, y, v;
cin >> x >> y >> v;
x ^= lastans, y ^= lastans, v ^= lastans;
push({x, y}, v);
} else if (opt == 2) {
int x1, y1, x2, y2;
cin >> x1 >> y1 >> x2 >> y2;
x1 ^= lastans, y1 ^= lastans, x2 ^= lastans, y2 ^= lastans;
cout << (lastans = query({x1, y1}, {x2, y2})) << endl;
} else {
break;
}
}
return 0;
}
2-D Tree 模板
#include <algorithm>
#include <iostream>
#define endl '\n'
using namespace std;
constexpr int N = 2e5 + 10, K = 21;
struct Pos {
int x, y;
};
bool contains(Pos lpos, Pos rpos, Pos pos) { return lpos.x <= pos.x && pos.x <= rpos.x && lpos.y <= pos.y && pos.y <= rpos.y; }
bool contains(Pos lpos, Pos rpos, Pos qlpos, Pos qrpos) { return lpos.x <= qlpos.x && qrpos.x <= rpos.x && lpos.y <= qlpos.y && qrpos.y <= rpos.y; }
bool notintersect(Pos lpos, Pos rpos, Pos qlpos, Pos qrpos) { return lpos.x > qrpos.x || rpos.x < qlpos.x || lpos.y > qrpos.y || rpos.y < qlpos.y; }
int POOL[N], POOL_TOP, IDX;
void poolPush(int u) {
POOL[++POOL_TOP] = u;
}
int poolGet() {
if (POOL_TOP) return POOL[POOL_TOP--];
return ++IDX;
}
struct Node {
Pos lpos, rpos, pos;
int val, sum, lson, rson, fg;
} e[N];
void pushup(int u) {
e[u].sum = e[e[u].lson].sum + e[e[u].rson].sum + e[u].val;
e[u].lpos = e[u].rpos = e[u].pos;
if (e[u].lson) {
e[u].lpos.x = min(e[u].lpos.x, e[e[u].lson].lpos.x);
e[u].rpos.x = max(e[u].rpos.x, e[e[u].lson].rpos.x);
e[u].lpos.y = min(e[u].lpos.y, e[e[u].lson].lpos.y);
e[u].rpos.y = max(e[u].rpos.y, e[e[u].lson].rpos.y);
}
if (e[u].rson) {
e[u].lpos.x = min(e[u].lpos.x, e[e[u].rson].lpos.x);
e[u].rpos.x = max(e[u].rpos.x, e[e[u].rson].rpos.x);
e[u].lpos.y = min(e[u].lpos.y, e[e[u].rson].lpos.y);
e[u].rpos.y = max(e[u].rpos.y, e[e[u].rson].rpos.y);
}
}
int buildTop;
pair<Pos, int> buildTmp[N];
int build(int l, int r, int curfg) {
if (l > r) return 0;
int u = poolGet();
if (l == r) {
e[u].pos = buildTmp[l].first, e[u].val = buildTmp[l].second;
e[u].lson = e[u].rson = 0;
e[u].fg = curfg;
pushup(u);
return u;
}
int mid = (l + r) >> 1;
nth_element(buildTmp + l, buildTmp + mid, buildTmp + r + 1, [&](const pair<Pos, int>& a, const pair<Pos, int>& b) { return curfg ? a.first.y < b.first.y : a.first.x < b.first.x; });
e[u].pos = buildTmp[mid].first, e[u].val = buildTmp[mid].second;
e[u].lson = build(l, mid - 1, curfg ^ 1);
e[u].rson = build(mid + 1, r, curfg ^ 1);
e[u].fg = curfg;
pushup(u);
return u;
}
void expand(int u) {
if (!u) return;
buildTmp[++buildTop] = {e[u].pos, e[u].val};
poolPush(u);
expand(e[u].lson);
expand(e[u].rson);
}
int merge(int u, int v) {
buildTop = 0;
expand(u), expand(v);
return build(1, buildTop, 0);
}
int roots[K];
void insert(Pos pos, int val) {
buildTmp[buildTop = 1] = {pos, val};
int root = build(1, buildTop, 0), i = 0;
for (; i < K && roots[i]; i++) {
root = merge(root, roots[i]);
roots[i] = 0;
}
roots[i] = root;
}
int queryone(Pos lpos, Pos rpos, int u) {
if (!u || notintersect(lpos, rpos, e[u].lpos, e[u].rpos)) return 0;
if (contains(lpos, rpos, e[u].lpos, e[u].rpos)) return e[u].sum;
int res = 0;
if (contains(lpos, rpos, e[u].pos)) res += e[u].val;
res += queryone(lpos, rpos, e[u].lson) + queryone(lpos, rpos, e[u].rson);
return res;
}
int query(Pos lpos, Pos rpos) {
int res = 0;
for (int i = 0; i < K; i++) res += queryone(lpos, rpos, roots[i]);
return res;
}
int lastans, FW_NUMBER, opt;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> FW_NUMBER;
while ((cin >> opt, opt) != 3) {
if (opt == 1) {
int val;
Pos pos;
cin >> pos.x >> pos.y >> val;
pos.x ^= lastans, pos.y ^= lastans, val ^= lastans;
insert(pos, val);
} else {
Pos lpos, rpos;
cin >> lpos.x >> lpos.y >> rpos.x >> rpos.y;
lpos.x ^= lastans, lpos.y ^= lastans, rpos.x ^= lastans, rpos.y ^= lastans;
cout << (lastans = query(lpos, rpos)) << endl;
}
}
return 0;
}
指针版本代码 from LBY:
#include<bits/stdc++.h>
using namespace std;
int n;
vector<pair<pair<int,int>,int> >re;
inline bool cmp1(pair<pair<int,int>,int>x,pair<pair<int,int>,int>y){return x.first.second<y.first.second;}
inline bool cmp0(pair<pair<int,int>,int>x,pair<pair<int,int>,int>y){return x.first.first<y.first.first;}
struct KDT{
int siz;
struct node{
node *l,*r;
int val,sum,p[2],mx[2],mn[2];
}*root,*null;
inline KDT(){
null=new node;
null->l=null->r=null;
null->val=null->sum=0,root=null,null->mn[0]=null->mn[1]=2e9,null->mx[0]=null->mx[1]=0;
siz=0;
}
inline node *new_node(){
node *p=new node;
p->l=p->r=null,p->val=0;
return p;
}
void update(int t[2],int v,node *&root,bool d){
if(root==null){
siz++;
root=new_node(),root->val=root->sum=v;
root->p[0]=root->mx[0]=root->mn[0]=t[0],root->p[1]=root->mx[1]=root->mn[1]=t[1];
return;
}
if(root->p[0]==t[0] && root->p[1]==t[1]){
root->val+=v,root->sum+=v;
return;
}
if(t[d]<root->p[d]) update(t,v,root->l,d^1);
else update(t,v,root->r,d^1);
root->sum=root->l->sum+root->r->sum+root->val;
root->mx[0]=max(max(root->l->mx[0],root->mx[0]),root->r->mx[0]),root->mx[1]=max(max(root->l->mx[1],root->mx[1]),root->r->mx[1]);
root->mn[0]=min(min(root->l->mn[0],root->mn[0]),root->r->mn[0]),root->mn[1]=min(min(root->l->mn[1],root->mn[1]),root->r->mn[1]);
}
int ask(int u[2],int v[2],node *root){
if(root==null) return 0;
if(root->mn[0]>=u[0] && root->mx[0]<=v[0] && root->mn[1]>=u[1] && root->mx[1]<=v[1]) return root->sum;
int tot=0;
if(root->p[0]>=u[0] && root->p[0]<=v[0] && root->p[1]>=u[1] && root->p[1]<=v[1]) tot=root->val;
if(!(root->l->mx[0]<u[0] || root->l->mn[0]>v[0] || root->l->mx[1]<u[1] || root->l->mn[1]>v[1])) tot+=ask(u,v,root->l);
if(!(root->r->mx[0]<u[0] || root->r->mn[0]>v[0] || root->r->mx[1]<u[1] || root->r->mn[1]>v[1])) tot+=ask(u,v,root->r);
return tot;
}
inline void FREE(node *root){
if(root==null) return;
FREE(root->l),FREE(root->r);
delete root;
}
inline void get_node(node *root){
if(root==null) return;
re.push_back({{root->p[0],root->p[1]},root->val});
get_node(root->l),get_node(root->r);
}
inline node *RECON(int l,int r,int d=0){
if(l>r) return null;
if(l==r){
int t[2]={re[l].first.first,re[l].first.second},v=re[l].second;
node *root=new_node();root->val=root->sum=v;
root->p[0]=root->mx[0]=root->mn[0]=t[0],root->p[1]=root->mx[1]=root->mn[1]=t[1];
return root;
}
if(d) sort(re.begin()+l,re.begin()+r+1,cmp1);else sort(re.begin()+l,re.begin()+r+1,cmp0);
int mid=l+r>>1;
int t[2]={re[mid].first.first,re[mid].first.second},v=re[mid].second;
node *t1=RECON(l,mid-1,d^1),*t2=RECON(mid+1,r,d^1);
node *u=new_node();u->val=v;
u->l=t1,u->r=t2;
u->p[0]=u->mx[0]=u->mn[0]=t[0],u->p[1]=u->mx[1]=u->mn[1]=t[1];
u->mx[0]=max(max(u->l->mx[0],u->mx[0]),u->r->mx[0]),u->mx[1]=max(max(u->l->mx[1],u->mx[1]),u->r->mx[1]);
u->mn[0]=min(min(u->l->mn[0],u->mn[0]),u->r->mn[0]),u->mn[1]=min(min(u->l->mn[1],u->mn[1]),u->r->mn[1]);
u->sum=u->l->sum+u->r->sum+u->val;
return u;
}
inline void recon(){
get_node(root);
FREE(root);
root=RECON(0,re.size()-1);
re.clear();
}
}ttt[50];
signed main() {
cin>>n;
int lans=0;
while(true){
int opt;
cin>>opt;
if(opt==1){
int x,y,a;
cin>>x>>y>>a;
x^=lans,y^=lans,a^=lans;
KDT T;
int t[2]={x,y};
T.update(t,a,T.root,0);
for(int i=0;i<=30;i++) if(ttt[i].siz==0){
ttt[i]=T;
break;
}else ttt[i].get_node(ttt[i].root),ttt[i].FREE(ttt[i].root),T.recon(),ttt[i]=KDT();
}
if(opt==2){
int x,y,x2,y2;
cin>>x>>y>>x2>>y2;
x^=lans,y^=lans,x2^=lans,y2^=lans;
int t1[2]={x,y},t2[2]={x2,y2},ans=0;
for(int i=0;i<=30;i++) ans+=ttt[i].ask(t1,t2,ttt[i].root);
cout<<(lans=ans)<<"\n";
}
if(opt==3) return 0;
}
return 0;
}
练习题
P2479 [SDOI2010] 捉迷藏 - 洛谷
P4169 [Violet] 天使玩偶/SJY摆棋子 - 洛谷
P2093 [国家集训队] JZPFAR - 洛谷
P4390 [BalkanOI 2007] Mokia 摩基亚 - 洛谷
P4475 巧克力王国 - 洛谷
P3769 [CH弱省胡策R2] TATT - 洛谷
P5471 [NOI2019] 弹跳 - 洛谷