K-D Tree-学习笔记

i207M

2018-11-24 11:28:43

Personal

# 注意nth_element的用法:中间的位置是中位数!!! 我调这个调了好久... ## K-D Tree 划分空间,解决类似d维空间k近邻点对的问题。 ## 建树方法 **KDTree是一颗二叉树!** 我们在每一个点,选择一维进行划分。选择哪一维?计算每一维的方差,选择方差最大的维。如何分治?使用nth_element将按第k维排序的中位数放在mid的位置,然后递归$[l,mid-1],[mid+1,r]$。 不过据说把d维轮流分和按方差最大的维分的复杂度是一样的。 ## 查找方法 以最近邻查找为例: 到每个点的时候都更新答案。 先找到这个点的位置,然后在回溯的过程中,可能还要去另一个儿子,条件是询问点到分割平面的距离小于当前答案。 复杂度貌似是$N^{1-1/k}$。 这个是理论上更正确(记录坐标的最大最小值)的代码: ```cpp int build(int l,int r,int d) { if(l>r) return 0; int mid=(l+r)>>1; cmpd=d; nth_element(p+l,p+mid,p+r+1,cmp); mx[mid][0]=mn[mid][0]=p[mid][0],mx[mid][1]=mn[mid][1]=p[mid][1]; ls[mid]=build(l,mid-1,d^1), rs[mid]=build(mid+1,r,d^1); if(ls[mid]) { mx[mid][0]=max(mx[ls[mid]][0],mx[mid][0]), mn[mid][0]=min(mn[ls[mid]][0],mn[mid][0]); mx[mid][1]=max(mx[ls[mid]][1],mx[mid][1]), mn[mid][1]=min(mn[ls[mid]][1],mn[mid][1]); } if(rs[mid]) { mx[mid][0]=max(mx[rs[mid]][0],mx[mid][0]), mn[mid][0]=min(mn[rs[mid]][0],mn[mid][0]); mx[mid][1]=max(mx[rs[mid]][1],mx[mid][1]), mn[mid][1]=min(mn[rs[mid]][1],mn[mid][1]); } return mid; } double ans; double calc(int u) { double x=max(0.0,q[0]-mx[u][0])+max(0.0,mn[u][0]-q[0]),y=max(0.0,q[1]-mx[u][1])+max(0.0,mn[u][1]-q[1]); return x*x+y*y; } void query(int l,int r,int d) { if(l>r) return; int mid=(l+r)>>1; if(qid!=mid) ans=min(ans,dis(q,p[mid])); double dl=calc(ls[mid]),dr=calc(rs[mid]); if(q[d]<p[mid][d]) { if(dl<ans) query(l,mid-1,d^1); if(dr<ans) query(mid+1,r,d^1); } else { if(dr<ans) query(mid+1,r,d^1); if(dl<ans) query(l,mid-1,d^1); } } ``` ------------ 这个是更好写的代码: ```cpp bool cmp(Node &a,Node &b) { return a[cmpd]<b[cmpd]; } void build(int l,int r) { if(l>r) return; int mid=(l+r)>>1; // calc variance double c[2]= {0.0,0.0}; for(ri d=0; d<=1; ++d) { double ave=0.0; for(ri i=l; i<=r; ++i) ave+=p[i][d]; ave/=(r-l+1); for(ri i=l; i<=r; ++i) c[d]+=(p[i][d]-ave)*(p[i][d]-ave); } cmpd=spl[mid]=(c[0]<c[1]); nth_element(p+l,p+mid,p+r+1,cmp); build(l,mid-1), build(mid+1,r); } double ans; void query(int l,int r) { if(l>r) return; int mid=(l+r)>>1,d=spl[mid]; if(qid!=mid) ans=min(ans,dis(q,p[mid])); double rad=(q[d]-p[mid][d])*(q[d]-p[mid][d]); if(q[d]<p[mid][d]) { query(l,mid-1); if(rad<ans) query(mid+1,r); } else { query(mid+1,r); if(rad<ans) query(l,mid-1); } } ``` --------- ## K-D Tree 区间操作的代码 BZOJ 简单题 ``` namespace KD { const double alp=0.7; const int inf=1e9; struct Node { int d[2]; Node() {} Node(const int _x,const int _y) { d[0]=_x,d[1]=_y; } int &operator[](const int x) { return d[x]; } const int operator[](const int x) const { return d[x]; } friend bool operator==(const Node &a,const Node &b) { return a[0]==b[0]&&a[1]==b[1]; } } pt[N]; int cmpd; bool cmp(const Node &a,const Node &b) { return a[cmpd]<b[cmpd]; } il bool cmp2(const pair<Node,int> &a,const pair<Node,int> &b) { return cmp(a.fi,b.fi); } int tot,ls[N],rs[N],val[N],sum[N],sz[N]; int mx[N][2],mn[N][2]; il void up(int x) { sum[x]=sum[ls[x]]+val[x]+sum[rs[x]]; sz[x]=sz[ls[x]]+1+sz[rs[x]]; mx[x][0]=max(max(mx[ls[x]][0],mx[rs[x]][0]),pt[x][0]); mn[x][0]=min(min(ls[x]?mn[ls[x]][0]:inf,rs[x]?mn[rs[x]][0]:inf),pt[x][0]); mx[x][1]=max(max(mx[ls[x]][1],mx[rs[x]][1]),pt[x][1]); mn[x][1]=min(min(ls[x]?mn[ls[x]][1]:inf,rs[x]?mn[rs[x]][1]:inf),pt[x][1]); } int tp; pair<Node,int> st[N]; int cur[N]; void dfs(int x) { if(!x) return; dfs(ls[x]); st[++tp]=mp(pt[x],val[x]),cur[tp]=x; dfs(rs[x]); } void setup(int &x,int d,int l,int r) { if(l>r) { x=0; return; } gm; cmpd=d; nth_element(st+l,st+mid,st+r+1,cmp2); x=cur[mid]; pt[x]=st[mid].fi,val[x]=st[mid].se; setup(ls[x],d^1,l,mid-1),setup(rs[x],d^1,mid+1,r); up(x); } void rebuild(int &x,int d) { tp=0; dfs(x); setup(x,d,1,tp); } int *ned,nd; void ins(int &x,int d,const Node &p,int v) { if(!x) { x=++tot; pt[x]=p,val[x]=v; up(x); return; } if(p==pt[x]) { val[x]+=v; up(x); return; } if(p[d]<pt[x][d]) ins(ls[x],d^1,p,v); else ins(rs[x],d^1,p,v); up(x); if(max(sz[rs[x]],sz[ls[x]])>sz[x]*alp) ned=&x,nd=d; } int query(int x,int lx,int ux,int ly,int uy) { if(!x||lx>mx[x][0]||ux<mn[x][0]||ly>mx[x][1]||uy<mn[x][1]) return 0; if(lx<=mn[x][0]&&mx[x][0]<=ux&&ly<=mn[x][1]&&mx[x][1]<=uy) return sum[x]; return query(ls[x],lx,ux,ly,uy)+(lx<=pt[x][0]&&pt[x][0]<=ux&&ly<=pt[x][1]&&pt[x][1]<=uy?val[x]:0)+query(rs[x],lx,ux,ly,uy); } int rt; il void Upd(int x,int y,int k) { ned=NULL; ins(rt,0,Node(x,y),k); if(ned!=NULL) rebuild(*ned,nd); } il int Query(int lx,int ux,int ly,int uy) { return query(rt,lx,ux,ly,uy); } } ```