Kd-Tree

· · 个人记录

根据方差进行维度划分

宏定义

变量

函数

代码

const int DE=4;//3维 
int nde,L[DE],R[DE];
struct Node{
    int w[DE],val;
}s[N];
bool cmp(int p,int q){
    return s[p].w[nde]==s[q].w[nde]?p<q:s[p].w[nde]<s[q].w[nde];
}
#define ls(p) a[p].tl
#define rs(p) a[p].tr
struct KD{
    const double A=0.725;
    int rt,tot,g[N],ans;
    struct Tree{
        int tl,tr,de,sz;
        int l[DE],r[DE],mx;
    }a[N];
    void dfs(int p){
        if(ls(p))dfs(ls(p));
        g[++tot]=p;
        if(rs(p))dfs(rs(p));
    }
    void update(int p,int son,int de){
        a[p].l[de]=min(a[p].l[de],a[son].l[de]);
        a[p].r[de]=max(a[p].r[de],a[son].r[de]);
    }
    void pushup(int p){
        a[p].sz=a[ls(p)].sz+a[rs(p)].sz+1;
        a[p].mx=max(s[p].val,max(a[ls(p)].mx,a[rs(p)].mx));
        for(int i=1;i<DE;i++)a[p].l[i]=a[p].r[i]=s[p].w[i];
        if(ls(p))for(int i=1;i<DE;i++)update(p,ls(p),i);
        if(rs(p))for(int i=1;i<DE;i++)update(p,rs(p),i);
    }
    double fangc(int l,int r,int de){
        double av=0,sa=0;
        for(int i=l;i<=r;i++)av+=s[g[i]].w[de];
        av/=double(r-l+1);
        for(int i=l;i<=r;i++)
            sa+=(av-s[g[i]].w[de])*(av-s[g[i]].w[de]);
        return sa;
    }
    int build(int l,int r){
        if(l>r)return 0;
        int p=(l+r)>>1;
        double fcm=0,fc;
        nde=0;
        for(int i=1;i<DE;i++){
            fc=fangc(l,r,i);
            if(!nde||fcm<fc)
                nde=i,fcm=fc;
        }
        nth_element(g+l,g+p,g+r+1,cmp);
        a[g[p]].de=nde;
        ls(g[p])=build(l,p-1);rs(g[p])=build(p+1,r);
        pushup(g[p]);return g[p];
    }
    void rebuild(int &p){
        tot=0;dfs(p);p=build(1,tot);
    } 
    bool bad(int p){
        return a[p].sz*A<=(double)max(a[ls(p)].sz,a[rs(p)].sz);
    }
    void insert(int &p,int x){
        if(!p){
            pushup(p=x);return;
        }
        if(s[x].w[a[p].de]<=s[p].w[a[p].de])
            insert(ls(p),x);
        else insert(rs(p),x);
        pushup(p);
        if(bad(p))rebuild(p);
    }
    void query(int p,int *l,int *r){
        if(!p)return;
        bool flag1=1,flag2=1;
        for(int i=1;i<DE;i++){
            if(r[i]<a[p].l[i]||a[p].r[i]<l[i])return;
            if(r[i]<a[p].r[i]||a[p].l[i]<l[i])flag1=0;
            if(r[i]<s[p].w[i]||s[p].w[i]<l[i])flag2=0;
        }
        if(flag1){
            ans=max(ans,a[p].mx);
            return;
        }
        ans=max(ans,flag2*s[p].val);
        if(ans<a[ls(p)].mx)
            query(ls(p),l,r);
        if(ans<a[rs(p)].mx)
            query(rs(p),l,r);
        return;
    }
    void change(int p,int x,int val){
        if(p==x){
            s[p].val=val;
            a[p].mx=max(a[p].mx,s[p].val);return;
        }
        nde=a[p].de; 
        if(cmp(x,p))
            change(ls(p),x,val);
        else change(rs(p),x,val);
        a[p].mx=max(a[p].mx,max(a[ls(p)].mx,a[rs(p)].mx));
    }
}kd;

根据深度进行维度划分

宏定义

变量

函数

代码

const int DE=4;//3维 
int nde,L[DE],R[DE];
struct Node{
    int w[DE],val;
}s[N];
bool cmp(int p,int q){
    return s[p].w[nde]==s[q].w[nde]?p<q:s[p].w[nde]<s[q].w[nde];
}
#define ls(p) a[p].tl
#define rs(p) a[p].tr
struct KD{
    const double A=0.725;
    int rt,tot,g[N],ans;
    struct Tree{
        int tl,tr,de,sz;
        int l[DE],r[DE],mx;
    }a[N];
    void dfs(int p){
        if(ls(p))dfs(ls(p));
        g[++tot]=p;
        if(rs(p))dfs(rs(p));
    }
    void update(int p,int son,int de){
        a[p].l[de]=min(a[p].l[de],a[son].l[de]);
        a[p].r[de]=max(a[p].r[de],a[son].r[de]);
    }
    void pushup(int p){
        a[p].sz=a[ls(p)].sz+a[rs(p)].sz+1;
        a[p].mx=max(s[p].val,max(a[ls(p)].mx,a[rs(p)].mx));
        for(int i=1;i<DE;i++)a[p].l[i]=a[p].r[i]=s[p].w[i];
        if(ls(p))for(int i=1;i<DE;i++)update(p,ls(p),i);
        if(rs(p))for(int i=1;i<DE;i++)update(p,rs(p),i);
    }
    int build(int lde,int l,int r){
        if(l>r)return 0;
        int p=(l+r)>>1;
        nde=(lde==3?1:lde+1);
        nth_element(g+l,g+p,g+r+1,cmp);
        a[g[p]].de=nde;
        ls(g[p])=build(a[g[p]].de,l,p-1);
        rs(g[p])=build(a[g[p]].de,p+1,r);
        pushup(g[p]);return g[p];
    }
    void rebuild(int &p){
        tot=0;dfs(p);p=build(1,1,tot);
    } 
    bool bad(int p){
        return a[p].sz*A<=(double)max(a[ls(p)].sz,a[rs(p)].sz);
    }
    void insert(int &p,int x){
        if(!p){
            pushup(p=x);return;
        }
        if(s[x].w[a[p].de]<=s[p].w[a[p].de])
            insert(ls(p),x);
        else insert(rs(p),x);
        pushup(p);
        if(bad(p))rebuild(p);
    }
    void query(int p,int *l,int *r){
        if(!p)return;
        bool flag1=1,flag2=1;
        for(int i=1;i<DE;i++){
            if(r[i]<a[p].l[i]||a[p].r[i]<l[i])return;
            if(r[i]<a[p].r[i]||a[p].l[i]<l[i])flag1=0;
            if(r[i]<s[p].w[i]||s[p].w[i]<l[i])flag2=0;
        }
        if(flag1){
            ans=max(ans,a[p].mx);
            return;
        }
        ans=max(ans,flag2*s[p].val);
        if(ans<a[ls(p)].mx)
            query(ls(p),l,r);
        if(ans<a[rs(p)].mx)
            query(rs(p),l,r);
        return;
    }
    void change(int p,int x,int val){
        if(p==x){
            s[p].val=val;
            a[p].mx=max(a[p].mx,s[p].val);return;
        }
        nde=a[p].de;
        if(cmp(x,p))
            change(ls(p),x,val);
        else change(rs(p),x,val);
        a[p].mx=max(a[p].mx,max(a[ls(p)].mx,a[rs(p)].mx));
    }
}kd;