题解:P14312 【模板】K-D Tree

· · 题解

发现还能写题解来写一篇
K-D Tree 是一种可以高效处理 k 维空间里点的信息,它相较于其他的数据结构,空间小、可以处理在线信息,是一种绝妙的骗分方式。但是 k 一般是 23,如果更大它的复杂度甚至还没有暴力的复杂度好。

1. 建树

K-D Tree 一般使用交叉建树和方差建树两种,一般使用交叉建树。交叉建树的建树方式是:

  1. 选择一个维度;
  2. 求出这个维度中点的中位数,将其作为目前的根节点;
  3. 按中位数把点分成两部分,并选择下一个维度,继续递归建树。

k=2 为例,给定点集 (5,5),(7,1),(3,6),(2,1),(6,3),(8,6),(9,2),则建树如下:这样建树树的深度就是严格 \log n 的。
现在的问题便是如何求中位数。这时我们注意到,在 algorithm 中提供了函数 nth_element(s+l,s+mid,s+r+1,cmp),可以在 O(n) 的时间复杂度内求出 s_ls_r 的区间内在 cmp 的排序规则下 s_{mid} 的值,这样静态建树的复杂度就变成了 O(n\log n)。代码如下:

int build(int l,int r,int type){
    if(l>r)return 0;
    int mid=l+r>>1;
    nth_element(tem+l,tem+mid,tem+r+1,[&](int a,int b){
        return t[a].x[type]<t[b].x[type];
    });
    int x=tem[mid];
    t[x].ls=build(l,mid-1,(type+1)%k);
    t[x].rs=build(mid+1,r,(type+1)%k);
    pushup(x);
    return x;
}

2. 查询

为了查询 k 维空间内的矩形区域中的所有点的信息,我们可以维护每一个维度上点的最大值和最小值,这样就可以快速判断查询矩形是否与当前子树有交叉、完全覆盖当前子树所有节点没有交点、没有交叉。这样查询的复杂度便为 O(n^{1-\frac{1}{k}})
对于这三种情况,我们可以写出以下代码:

int query(int p){
    if(p==0)return 0;
    pushdown(p);
    for(int i=0;i<=k-1;i++)if(t[p].mi[i]>R[i]||t[p].ma[i]<L[i])return 0;//没有交叉直接返回
    bool flag=true;
    for(int i=0;i<=k-1;i++){
        if(L[i]>t[p].mi[i]||t[p].ma[i]>R[i]){
            flag=false;
            break;
        }
    }
    if(flag)return t[p].sum;//被完全覆盖
    int res=0;
    flag=true;
    for(int i=0;i<=k-1;i++){
        if(t[p].x[i]>R[i]||t[p].x[i]<L[i]){
            flag=false;
            break;
        }
    }
    if(flag)res=t[p].v;
    return res+query(t[p].ls)+query(t[p].rs);//有交叉继续递归左右子树
}

3. 动态建树

对于需要中途添加或删除点的情况,我们可以使用二进制分组保持树的平衡,并仍然能保证时间复杂度。
二进制分组,就是建立若干棵 K-D Tree,第 i 棵存 2^i,它们的大小之和为 n。可以证明只需要 \log n+1 棵树。
对于加入一个点的操作,我们可以把这个点看做一棵大小为 1 的树,将其加入树中。此时如果树中有大小为 1 的树,则将两棵树合并为大小为 2 的树;这时如果树中有大小为 2 的树,则将两棵树合并为大小为 4 的树,。以此类推,直至不冲突。
因为个点每参与一次合并,其所在的树大小就会翻倍。因为最多只有 O(\log n) 个大小级别,所以每个点最多会被重构 O(\log n) 次。单次插入的均摊时间复杂度是 O(\log^2n)。查询时,只需对当前树中非空的每一棵树都进行一次查询,最后将所有结果汇总即可。因为有 O(\log n) 棵树,查询的总体时间复杂度不变,依然是 O(\sqrt{n}) 的复杂度。
对于本题,因为有区间查询,我们只需要添加一个 lazy 标记即可。
submission

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
struct node{
    int x[3],v,sum,ls,rs,siz,mi[3],ma[3],lazy;
}t[maxn];
int n,lst,rt[20],tem[maxn],L[3],R[3],k,cnt,val,m;
void pushdown(int p){
    if(t[p].lazy){
        if(t[p].ls){
            t[t[p].ls].sum+=t[p].lazy*t[t[p].ls].siz;
            t[t[p].ls].v+=t[p].lazy;
            t[t[p].ls].lazy+=t[p].lazy;
        }
        if(t[p].rs){
            t[t[p].rs].sum+=t[p].lazy*t[t[p].rs].siz;
            t[t[p].rs].v+=t[p].lazy;
            t[t[p].rs].lazy+=t[p].lazy;
        }
        t[p].lazy=0;
    }
}
void pushup(int p){
    t[p].siz=1+t[t[p].ls].siz+t[t[p].rs].siz;
    t[p].sum=t[t[p].ls].sum+t[p].v+t[t[p].rs].sum;
    for(int i=0;i<=k-1;i++){
        t[p].mi[i]=t[p].ma[i]=t[p].x[i];
        if(t[p].ls){
            t[p].mi[i]=min(t[p].mi[i],t[t[p].ls].mi[i]);
            t[p].ma[i]=max(t[p].ma[i],t[t[p].ls].ma[i]);
        }
        if(t[p].rs){
            t[p].mi[i]=min(t[p].mi[i],t[t[p].rs].mi[i]);
            t[p].ma[i]=max(t[p].ma[i],t[t[p].rs].ma[i]);
        }
    }
}
int build(int l,int r,int type){
    if(l>r)return 0;
    int mid=l+r>>1;
    nth_element(tem+l,tem+mid,tem+r+1,[&](int a,int b){
        return t[a].x[type]<t[b].x[type];
    });
    int x=tem[mid];
    t[x].ls=build(l,mid-1,(type+1)%k);
    t[x].rs=build(mid+1,r,(type+1)%k);
    pushup(x);
    return x;
}
void append(int &p){
    if(p==0)return;
    pushdown(p);
    tem[++cnt]=p;
    append(t[p].ls);
    append(t[p].rs);
    p=0;
}
void add(int p){
    if(p==0)return;
    pushdown(p);
    for(int i=0;i<=k-1;i++)if(t[p].mi[i]>R[i]||t[p].ma[i]<L[i])return;
    bool flag=true;
    for(int i=0;i<=k-1;i++){
        if(t[p].mi[i]<L[i]||t[p].ma[i]>R[i]){
            flag=false;
            break;
        }
    }
    if(flag){
        t[p].sum+=t[p].siz*val,t[p].v+=val,t[p].lazy+=val;
        return;
    }
    flag=true;
    for(int i=0;i<=k-1;i++){
        if(t[p].x[i]>R[i]||t[p].x[i]<L[i]){
            flag=false;
            break;
        }
    }
    if(flag)t[p].v+=val;
    add(t[p].ls);
    add(t[p].rs);
    pushup(p);
}
int query(int p){
    if(p==0)return 0;
    pushdown(p);
    for(int i=0;i<=k-1;i++)if(t[p].mi[i]>R[i]||t[p].ma[i]<L[i])return 0;
    bool flag=true;
    for(int i=0;i<=k-1;i++){
        if(L[i]>t[p].mi[i]||t[p].ma[i]>R[i]){
            flag=false;
            break;
        }
    }
    if(flag)return t[p].sum;
    int res=0;
    flag=true;
    for(int i=0;i<=k-1;i++){
        if(t[p].x[i]>R[i]||t[p].x[i]<L[i]){
            flag=false;
            break;
        }
    }
    if(flag)res=t[p].v;
    return res+query(t[p].ls)+query(t[p].rs);
}
signed main(){
    cin>>k>>m;
    n=lst=0;
    while(m--){
        int op;
        cin>>op;
        if(op==1){
            n++;
            for(int i=0;i<=k-1;i++){
                cin>>t[n].x[i];
                t[n].x[i]^=lst;
            }
            cin>>t[n].v;
            t[n].v^=lst;
            tem[cnt=1]=n;
            for(int i=1;;i++){
                if(rt[i]==0){
                    rt[i]=build(1,cnt,0);
                    break;
                }else append(rt[i]);
            }
        }else if(op==2){
            for(int i=0;i<=k-1;i++){
                cin>>L[i];
                L[i]^=lst;
            }
            for(int i=0;i<=k-1;i++){
                cin>>R[i];
                R[i]^=lst;
            }
            cin>>val;
            val^=lst;
            for(int i=1;i<=18;i++)add(rt[i]);
        }else if(op==3){
            for(int i=0;i<=k-1;i++){
                cin>>L[i];
                L[i]^=lst;
            }
            for(int i=0;i<=k-1;i++){
                cin>>R[i];
                R[i]^=lst;
            }
            lst=0;
            for(int i=1;i<=18;i++)lst+=query(rt[i]);
            cout<<lst<<"\n";
        }
    }
}