K-D Tree入门

· · 算法·理论

K-D Tree(KDT, k-Dimension Tree) 是一种快速解决多维信息维护的数据结构。

以下假设 n 为 K-D Tree 的大小,m 为查询次数,插入次数明显与 n 同阶。

引入

在二叉搜索树中,我们对一个集合通过一个分水岭元素分成两个集合,一个集合中的元素都比分水岭元素要小,另一个集合中的元素都比分水岭元素要大。之后依照这样建成一颗二叉树。之后我们就可以在这棵树上进行各种操作。

因为二叉搜索树中每个元素都是一个标量,可以视为维数为 1 的向量或 1 维数轴上的点。所以二叉搜索树事实上就是维护了一维空间中的点。我们尝试把 1 维拓展到 k 维,维护 k 维的点,并仿照二叉搜索树的建立方式去维护一个数据结构,就是 K-D Tree。

建树

二叉搜索树中的元素只有 1 维,因此划分直接按照这一维划分。而 K-D Tree 有 k 维,所以我们需要从 k 维之中选择 1 维进行划分。于是我们就可以列出建树的流程:

这样我们就可以建出一颗 K-D Tree 了。为了方便理解,这里搬 OI Wiki 上的一张图:

图中红色的线表示第一次划分的线,蓝色表示第二次划分的线。上面哪些点根据这样的划分建出来的 K-D Tree 长这个样子(也是从 OI Wiki 盗的):

优化

但是直接用上面的方式建出来的树对查询操作极不友好。我们需要用一些优化让复杂度正确:

在加上以上两个优化之后,我们可以得到一个树高 \lg n+O(1) 的,看起来平衡的 K-D Tree。

接着解决如何快速找到中位数的问题。如果直接 sort 一遍,复杂度是 O(n\lg n) 的。但是我们只需要找一个位置的数就行了,可以转化为序列找第 k 小的问题,可以用快速排序转化一下做,而标准库中也有实现这个功能的函数: nth_element,可以在 O(n) 的时间内找到第 k 大。就可以写成 nth_element(s + l, s + mid, s + r + 1, cmp)

注意:在划分某一维的时候,点中这一维可能会有重复的数,所以我们可能会把两个在这一维相同的数划分到两个子树中。这点在查询的时候要格外注意。(否则树高就不保证是 \lg n+O(1) 的,查询复杂度就伪了)

于是加上这些优化之后 K-D Tree 的建树复杂度就成为了 O(n\lg n) 的。

具体实现看最后的代码。

查询

这里主要讲一下查询一个超正方体中所包含的元素值。我们仿照线段树的思路进行分治查询就行了。具体来说,就是先判断与查询超正方体有没有交点,如果没有就返回;否则判断是否被查询超正方体完全包含,如果完全包含,就直接把这个答案统计进去,否则递归两个子树查询。

时间复杂度我也不会证明,自己去看 OI-Wiki 吧。总之复杂度是 O(n^{1-\frac 1k}+\lg n) 的(这也能看出平衡树就是 1-D Tree)。

具体实现也看最后的代码。

修改

我们发现 K-D Tree 貌似不支持修改操作,而且又不能像平衡树一样用各种奇奇怪怪的方式解决,只能重构,那么复杂度还不如暴力。

修改主要有两种解决方案:根号重构和二进制分组。具体怎么选择看你自己。反正二进制分组吊打根号重构就是了。

根号重构

我们选择不每一次插入都进行一次重构,而是进行 B 次插入后再进行查询。其它的暂时存在数组中,查询的时候遍历一遍判断是否需要计算贡献。

于是插入的均摊复杂度很明显是 O(\frac{n\lg n}{B}) 的,查询是 O(B+n^{1-\frac 1k}) 的。一共的复杂度就是 O(n\frac{n\lg n}{B}+m(B+n^{1-\frac 1k}))。如果 n,m 同阶,那么 B 的最佳取值为 O(\sqrt{n\lg n}),修改复杂度就是 O(\sqrt{n\lg n}),查询复杂度就是 O(\sqrt{n\lg n}+n^{1-\frac1k});否则 B 的最佳取值为 O(\sqrt{\frac nmn\lg n}),修改复杂度就是 O(\frac mn\sqrt{\frac nmn\lg n)},查询复杂度为 O(\sqrt{\frac nmn\lg n}) 的。当然一般取 n,m 同阶的情况就行了。

二进制分组

如果点的个数为 n,我们就可以维护 \lfloor\lg n\rfloor 棵 K-D Tree,其中第 i 棵(从 0 开始计)存储 2^i 个节点。然后仿照二进制进位的方式插入之后进位(合并)。具体来说就是新增一颗大小为 1 的 K-D Tree,然后不断合并大小相同的树。这样插入的总复杂度是

\begin{align} O\left(\sum_{k=0}^{\lfloor\lg n\rfloor}\lfloor\frac{n}{2^k}\rfloor2^k\lg\left(2^k\right)\right)&=O\left(\sum_{k=0}^{\lfloor\lg n\rfloor}\frac{n}{2^k}2^kk\right)\\ &=O\left(\sum_{k=0}^{\lfloor\lg n\rfloor}nk\right)\\ &=O\left(n\sum_{k=0}^{\lfloor\lg n\rfloor}k\right)\\ &=O\left(n\frac{\lfloor\lg n\rfloor(\lfloor\lg n\rfloor+1)}{2}\right)\\ &=O(n\lg^2n) \end{align}

于是均摊复杂度就是 O(\lg^2n) 的了。

查询的时候,就分别在这些树上分别查询,于是单次查询的复杂度就是 O(n^{1-\frac1k}) 的了。

代码实现

从复杂度不难发现二进制分组的效率要高得多,所以我写的是二进制分组。

模板题 AC 代码(K-D Tree 模板):

#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

constexpr int K = 2, N = 2e5 + 10, KN = 20;

class Position {
public:
    Position() {
        for (int i = 0; i < K; i++) dat[i] = 0;
    }

    Position(initializer_list<int> pos) {
        auto it = pos.begin();
        for (int i = 0; i < K; i++, it++) dat[i] = *it;
    }

    Position(const Position& other) {
        for (int i = 0; i < K; i++) dat[i] = other.dat[i];
    }

    const int& operator[](int index) const {
        return dat[index];
    }

    int& operator[](int index) {
        return dat[index];
    }

    void operator=(const Position& other) {
        for (int i = 0; i < K; i++) dat[i] = other.dat[i];
    }

    bool operator==(const Position& other) const {
        for (int i = 0; i < K; i++)
            if (dat[i] != other.dat[i])
                return false;
        return true;
    }

private:
    int dat[K];
};

bool contains(const Position& lpos, const Position& rpos, const Position& pos) {
    bool flag = true;
    for (int i = 0; i < K; i++) flag = flag && (lpos[i] <= pos[i] && pos[i] <= rpos[i]);
    return flag;
}

bool contains(const Position& lpos, const Position& rpos, const Position& qlpos, const Position& qrpos) {
    bool flag = true;
    for (int i = 0; i < K; i++) flag = flag && (lpos[i] <= qlpos[i] && qrpos[i] <= rpos[i]);
    return flag;
}

bool notintersect(const Position& lpos, const Position& rpos, const Position& qlpos, const Position& qrpos) {
    bool flag = false;
    for (int i = 0; i < K; i++) flag = flag || (qrpos[i] < lpos[i] || qlpos[i] > rpos[i]);
    return flag;
}

ostream& operator<<(ostream& output, Position pos) {
    output << "[";
    for (int i = 0; i < K; i++) output << pos[i] << (i == K - 1 ? "" : ", ");
    output << "]";
    return output;
}

int trushCan[N], trushTop = 0, curIdx;

void poolPush(int pos) {
    trushCan[++trushTop] = pos;
}

int poolGet() {
    if (trushTop) return trushCan[trushTop--];
    return ++curIdx;
}

struct Node {
    int val;
    Position pos;
    int sum;
    Position lpos, rpos;
    int lson, rson, fg;

    Node() : val(0), pos(), sum(0), lpos(), rpos(), lson(0), rson(0), fg(0) {
    }

    Node(int val, Position pos) : val(val), pos(pos), sum(val), lpos(), rpos(), lson(0), rson(0), fg(0) {
    }
} tree[N];

void pushup(int x) {
    tree[x].sum = tree[tree[x].lson].sum + tree[tree[x].rson].sum + tree[x].val;
    tree[x].lpos = tree[x].pos, tree[x].rpos = tree[x].pos;
    for (int i = 0; i < K; i++) {
        if (tree[x].lson) {
            tree[x].lpos[i] = min(tree[x].lpos[i], tree[tree[x].lson].lpos[i]);
            tree[x].rpos[i] = max(tree[x].rpos[i], tree[tree[x].lson].rpos[i]);
        }
        if (tree[x].rson) {
            tree[x].lpos[i] = min(tree[x].lpos[i], tree[tree[x].rson].lpos[i]);
            tree[x].rpos[i] = max(tree[x].rpos[i], tree[tree[x].rson].rpos[i]);
        }
    }
}

void print(int x, int cs) {
    for (int i = 0; i < cs; i++) cout << '-';
    cout << "At[" << x << "], val=" << tree[x].val << ", pos=" << tree[x].pos << "; sum=" << tree[x].sum << ", lpos=" << tree[x].lpos << ", rpos=" << tree[x].rpos << "; lson=" << tree[x].lson << ", rson=" << tree[x].rson << ", fg=" << tree[x].fg << endl;
    if (tree[x].lson) print(tree[x].lson, cs + 1);
    if (tree[x].rson) print(tree[x].rson, cs + 1);
}

int mergeTop;
pair<Position, int> mergeTmp[N];

void expend(int cur) {
    if (tree[cur].lson) expend(tree[cur].lson);
    poolPush(cur);
    mergeTmp[++mergeTop] = make_pair(tree[cur].pos, tree[cur].val);
    if (tree[cur].rson) expend(tree[cur].rson);
}

int rebuild(int l, int r, int curfg) {
    if (l > r) return 0;
    int x = poolGet();
    if (l == r) {
        tree[x].pos = mergeTmp[l].first;
        tree[x].val = mergeTmp[l].second;
        tree[x].fg = curfg;
        tree[x].lson = tree[x].rson = 0;
        pushup(x);
        return x;
    }
    int mid = (l + r + 1) >> 1;
    nth_element(mergeTmp + l, mergeTmp + mid, mergeTmp + r + 1, [&](const pair<Position, int>& a, const pair<Position, int>& b) { return a.first[curfg] < b.first[curfg]; });
    tree[x].pos = mergeTmp[mid].first;
    tree[x].val = mergeTmp[mid].second;
    tree[x].fg = curfg;
    tree[x].lson = rebuild(l, mid - 1, (curfg + 1) % K);
    tree[x].rson = rebuild(mid + 1, r, (curfg + 1) % K);
    pushup(x);
    return x;
}

int merge(int ra, int rb) {
    mergeTop = 0;
    expend(ra);
    expend(rb);
    sort(mergeTmp + 1, mergeTmp + mergeTop + 1, [](pair<Position, int> a, pair<Position, int> b) {
        for (int i = 0; i < K; i++)
            if (a.first[i] != b.first[i])
                return a.first[i] < b.first[i];
        return a.second < b.second;
    });
    int j = 0;
    for (int i = 1; i <= mergeTop; i++) {
        if (j && mergeTmp[i].first == mergeTmp[j].first) {
            mergeTmp[j].second += mergeTmp[i].second;
        } else {
            mergeTmp[++j] = mergeTmp[i];
        }
    }
    mergeTop = j;
    return rebuild(1, mergeTop, 0);
}

int roots[KN];

void push(Position pos, int val) {
    mergeTmp[mergeTop = 1] = {pos, val};
    int rt = rebuild(1, 1, 0);
    int i = 0;
    for (; i < KN && roots[i]; i++) {
        rt = merge(rt, roots[i]);
        roots[i] = 0;
    }
    roots[i] = rt;
}

int queryone(int cur, Position& lpos, Position& rpos) {
    if (!cur) return 0;
    if (notintersect(lpos, rpos, tree[cur].lpos, tree[cur].rpos)) return 0;
    if (contains(lpos, rpos, tree[cur].lpos, tree[cur].rpos)) return tree[cur].sum;
    int ans = 0;
    if (contains(lpos, rpos, tree[cur].pos)) ans += tree[cur].val;
    ans += queryone(tree[cur].lson, lpos, rpos);
    ans += queryone(tree[cur].rson, lpos, rpos);
    return ans;
}

int query(Position lpos, Position rpos) {
    int ans = 0;
    for (int i = 0; i < KN; i++) ans += queryone(roots[i], lpos, rpos);
    return ans;
}

int main() {
    int n, lastans = 0;
    cin >> n;
    while (true) {
        int opt;
        cin >> opt;
        if (opt == 1) {
            int x, y, v;
            cin >> x >> y >> v;
            x ^= lastans, y ^= lastans, v ^= lastans;
            push({x, y}, v);
        } else if (opt == 2) {
            int x1, y1, x2, y2;
            cin >> x1 >> y1 >> x2 >> y2;
            x1 ^= lastans, y1 ^= lastans, x2 ^= lastans, y2 ^= lastans;
            cout << (lastans = query({x1, y1}, {x2, y2})) << endl;
        } else {
            break;
        }
    }
    return 0;
}

2-D Tree 模板

#include <algorithm>
#include <iostream>
#define endl '\n'

using namespace std;

constexpr int N = 2e5 + 10, K = 21;

struct Pos {
    int x, y;
};

bool contains(Pos lpos, Pos rpos, Pos pos) { return lpos.x <= pos.x && pos.x <= rpos.x && lpos.y <= pos.y && pos.y <= rpos.y; }
bool contains(Pos lpos, Pos rpos, Pos qlpos, Pos qrpos) { return lpos.x <= qlpos.x && qrpos.x <= rpos.x && lpos.y <= qlpos.y && qrpos.y <= rpos.y; }
bool notintersect(Pos lpos, Pos rpos, Pos qlpos, Pos qrpos) { return lpos.x > qrpos.x || rpos.x < qlpos.x || lpos.y > qrpos.y || rpos.y < qlpos.y; }

int POOL[N], POOL_TOP, IDX;

void poolPush(int u) {
    POOL[++POOL_TOP] = u;
}

int poolGet() {
    if (POOL_TOP) return POOL[POOL_TOP--];
    return ++IDX;
}

struct Node {
    Pos lpos, rpos, pos;
    int val, sum, lson, rson, fg;
} e[N];

void pushup(int u) {
    e[u].sum = e[e[u].lson].sum + e[e[u].rson].sum + e[u].val;
    e[u].lpos = e[u].rpos = e[u].pos;
    if (e[u].lson) {
        e[u].lpos.x = min(e[u].lpos.x, e[e[u].lson].lpos.x);
        e[u].rpos.x = max(e[u].rpos.x, e[e[u].lson].rpos.x);
        e[u].lpos.y = min(e[u].lpos.y, e[e[u].lson].lpos.y);
        e[u].rpos.y = max(e[u].rpos.y, e[e[u].lson].rpos.y);
    }
    if (e[u].rson) {
        e[u].lpos.x = min(e[u].lpos.x, e[e[u].rson].lpos.x);
        e[u].rpos.x = max(e[u].rpos.x, e[e[u].rson].rpos.x);
        e[u].lpos.y = min(e[u].lpos.y, e[e[u].rson].lpos.y);
        e[u].rpos.y = max(e[u].rpos.y, e[e[u].rson].rpos.y);
    }
}

int buildTop;
pair<Pos, int> buildTmp[N];

int build(int l, int r, int curfg) {
    if (l > r) return 0;
    int u = poolGet();
    if (l == r) {
        e[u].pos = buildTmp[l].first, e[u].val = buildTmp[l].second;
        e[u].lson = e[u].rson = 0;
        e[u].fg = curfg;
        pushup(u);
        return u;
    }
    int mid = (l + r) >> 1;
    nth_element(buildTmp + l, buildTmp + mid, buildTmp + r + 1, [&](const pair<Pos, int>& a, const pair<Pos, int>& b) { return curfg ? a.first.y < b.first.y : a.first.x < b.first.x; });
    e[u].pos = buildTmp[mid].first, e[u].val = buildTmp[mid].second;
    e[u].lson = build(l, mid - 1, curfg ^ 1);
    e[u].rson = build(mid + 1, r, curfg ^ 1);
    e[u].fg = curfg;
    pushup(u);
    return u;
}

void expand(int u) {
    if (!u) return;
    buildTmp[++buildTop] = {e[u].pos, e[u].val};
    poolPush(u);
    expand(e[u].lson);
    expand(e[u].rson);
}

int merge(int u, int v) {
    buildTop = 0;
    expand(u), expand(v);
    return build(1, buildTop, 0);
}

int roots[K];

void insert(Pos pos, int val) {
    buildTmp[buildTop = 1] = {pos, val};
    int root = build(1, buildTop, 0), i = 0;
    for (; i < K && roots[i]; i++) {
        root = merge(root, roots[i]);
        roots[i] = 0;
    }
    roots[i] = root;
}

int queryone(Pos lpos, Pos rpos, int u) {
    if (!u || notintersect(lpos, rpos, e[u].lpos, e[u].rpos)) return 0;
    if (contains(lpos, rpos, e[u].lpos, e[u].rpos)) return e[u].sum;
    int res = 0;
    if (contains(lpos, rpos, e[u].pos)) res += e[u].val;
    res += queryone(lpos, rpos, e[u].lson) + queryone(lpos, rpos, e[u].rson);
    return res;
}

int query(Pos lpos, Pos rpos) {
    int res = 0;
    for (int i = 0; i < K; i++) res += queryone(lpos, rpos, roots[i]);
    return res;
}

int lastans, FW_NUMBER, opt;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> FW_NUMBER;
    while ((cin >> opt, opt) != 3) {
        if (opt == 1) {
            int val;
            Pos pos;
            cin >> pos.x >> pos.y >> val;
            pos.x ^= lastans, pos.y ^= lastans, val ^= lastans;
            insert(pos, val);
        } else {
            Pos lpos, rpos;
            cin >> lpos.x >> lpos.y >> rpos.x >> rpos.y;
            lpos.x ^= lastans, lpos.y ^= lastans, rpos.x ^= lastans, rpos.y ^= lastans;
            cout << (lastans = query(lpos, rpos)) << endl;
        }
    }
    return 0;
}

指针版本代码 from LBY:

#include<bits/stdc++.h>
using namespace std;
int n;
vector<pair<pair<int,int>,int> >re;
inline bool cmp1(pair<pair<int,int>,int>x,pair<pair<int,int>,int>y){return x.first.second<y.first.second;}
inline bool cmp0(pair<pair<int,int>,int>x,pair<pair<int,int>,int>y){return x.first.first<y.first.first;}
struct KDT{
    int siz;
    struct node{
        node *l,*r;
        int val,sum,p[2],mx[2],mn[2];
    }*root,*null;
    inline KDT(){
        null=new node;
        null->l=null->r=null;
        null->val=null->sum=0,root=null,null->mn[0]=null->mn[1]=2e9,null->mx[0]=null->mx[1]=0;
        siz=0;
    }
    inline node *new_node(){
        node *p=new node;
        p->l=p->r=null,p->val=0;
        return p;
    }
    void update(int t[2],int v,node *&root,bool d){
        if(root==null){
            siz++;
            root=new_node(),root->val=root->sum=v;
            root->p[0]=root->mx[0]=root->mn[0]=t[0],root->p[1]=root->mx[1]=root->mn[1]=t[1];
            return;
        }
        if(root->p[0]==t[0] && root->p[1]==t[1]){
            root->val+=v,root->sum+=v;
            return;
        }
        if(t[d]<root->p[d]) update(t,v,root->l,d^1);
        else update(t,v,root->r,d^1);
        root->sum=root->l->sum+root->r->sum+root->val;
        root->mx[0]=max(max(root->l->mx[0],root->mx[0]),root->r->mx[0]),root->mx[1]=max(max(root->l->mx[1],root->mx[1]),root->r->mx[1]);
        root->mn[0]=min(min(root->l->mn[0],root->mn[0]),root->r->mn[0]),root->mn[1]=min(min(root->l->mn[1],root->mn[1]),root->r->mn[1]);
    }
    int ask(int u[2],int v[2],node *root){
        if(root==null) return 0;
        if(root->mn[0]>=u[0] && root->mx[0]<=v[0] && root->mn[1]>=u[1] && root->mx[1]<=v[1]) return root->sum;
        int tot=0;
        if(root->p[0]>=u[0] && root->p[0]<=v[0] && root->p[1]>=u[1] && root->p[1]<=v[1]) tot=root->val;
        if(!(root->l->mx[0]<u[0] || root->l->mn[0]>v[0] || root->l->mx[1]<u[1] || root->l->mn[1]>v[1])) tot+=ask(u,v,root->l);
        if(!(root->r->mx[0]<u[0] || root->r->mn[0]>v[0] || root->r->mx[1]<u[1] || root->r->mn[1]>v[1])) tot+=ask(u,v,root->r);
        return tot;
    }
    inline void FREE(node *root){
        if(root==null) return;
        FREE(root->l),FREE(root->r);
        delete root;
    }
    inline void get_node(node *root){
        if(root==null) return;
        re.push_back({{root->p[0],root->p[1]},root->val});
        get_node(root->l),get_node(root->r);
    }
    inline node *RECON(int l,int r,int d=0){
        if(l>r) return null;
        if(l==r){
            int t[2]={re[l].first.first,re[l].first.second},v=re[l].second;
            node *root=new_node();root->val=root->sum=v;
            root->p[0]=root->mx[0]=root->mn[0]=t[0],root->p[1]=root->mx[1]=root->mn[1]=t[1];
            return root;
        }
        if(d) sort(re.begin()+l,re.begin()+r+1,cmp1);else sort(re.begin()+l,re.begin()+r+1,cmp0);
        int mid=l+r>>1;
        int t[2]={re[mid].first.first,re[mid].first.second},v=re[mid].second;
        node *t1=RECON(l,mid-1,d^1),*t2=RECON(mid+1,r,d^1);
        node *u=new_node();u->val=v;
        u->l=t1,u->r=t2;
        u->p[0]=u->mx[0]=u->mn[0]=t[0],u->p[1]=u->mx[1]=u->mn[1]=t[1];
        u->mx[0]=max(max(u->l->mx[0],u->mx[0]),u->r->mx[0]),u->mx[1]=max(max(u->l->mx[1],u->mx[1]),u->r->mx[1]);
        u->mn[0]=min(min(u->l->mn[0],u->mn[0]),u->r->mn[0]),u->mn[1]=min(min(u->l->mn[1],u->mn[1]),u->r->mn[1]);
        u->sum=u->l->sum+u->r->sum+u->val;
        return u;
    }
    inline void recon(){
        get_node(root);
        FREE(root);
        root=RECON(0,re.size()-1);
        re.clear();
    }
}ttt[50]; 
signed main() {
    cin>>n;
    int lans=0;
    while(true){
        int opt;
        cin>>opt;
        if(opt==1){
            int x,y,a;
            cin>>x>>y>>a;
            x^=lans,y^=lans,a^=lans;
            KDT T;
            int t[2]={x,y};
            T.update(t,a,T.root,0);
            for(int i=0;i<=30;i++) if(ttt[i].siz==0){
                ttt[i]=T;
                break;
            }else ttt[i].get_node(ttt[i].root),ttt[i].FREE(ttt[i].root),T.recon(),ttt[i]=KDT();
        }
        if(opt==2){
            int x,y,x2,y2;
            cin>>x>>y>>x2>>y2;
            x^=lans,y^=lans,x2^=lans,y2^=lans;
            int t1[2]={x,y},t2[2]={x2,y2},ans=0; 
            for(int i=0;i<=30;i++) ans+=ttt[i].ask(t1,t2,ttt[i].root);
            cout<<(lans=ans)<<"\n";
        }
        if(opt==3) return 0;
    }
    return 0;
}

练习题

P2479 [SDOI2010] 捉迷藏 - 洛谷

P4169 [Violet] 天使玩偶/SJY摆棋子 - 洛谷

P2093 [国家集训队] JZPFAR - 洛谷

P4390 [BalkanOI 2007] Mokia 摩基亚 - 洛谷

P4475 巧克力王国 - 洛谷

P3769 [CH弱省胡策R2] TATT - 洛谷

P5471 [NOI2019] 弹跳 - 洛谷