K-D Tree-学习笔记
i207M
2018-11-24 11:28:43
# 注意nth_element的用法:中间的位置是中位数!!!
我调这个调了好久...
## K-D Tree
划分空间,解决类似d维空间k近邻点对的问题。
## 建树方法
**KDTree是一颗二叉树!**
我们在每一个点,选择一维进行划分。选择哪一维?计算每一维的方差,选择方差最大的维。如何分治?使用nth_element将按第k维排序的中位数放在mid的位置,然后递归$[l,mid-1],[mid+1,r]$。
不过据说把d维轮流分和按方差最大的维分的复杂度是一样的。
## 查找方法
以最近邻查找为例:
到每个点的时候都更新答案。
先找到这个点的位置,然后在回溯的过程中,可能还要去另一个儿子,条件是询问点到分割平面的距离小于当前答案。
复杂度貌似是$N^{1-1/k}$。
这个是理论上更正确(记录坐标的最大最小值)的代码:
```cpp
int build(int l,int r,int d)
{
if(l>r) return 0;
int mid=(l+r)>>1;
cmpd=d;
nth_element(p+l,p+mid,p+r+1,cmp);
mx[mid][0]=mn[mid][0]=p[mid][0],mx[mid][1]=mn[mid][1]=p[mid][1];
ls[mid]=build(l,mid-1,d^1), rs[mid]=build(mid+1,r,d^1);
if(ls[mid])
{
mx[mid][0]=max(mx[ls[mid]][0],mx[mid][0]), mn[mid][0]=min(mn[ls[mid]][0],mn[mid][0]);
mx[mid][1]=max(mx[ls[mid]][1],mx[mid][1]), mn[mid][1]=min(mn[ls[mid]][1],mn[mid][1]);
}
if(rs[mid])
{
mx[mid][0]=max(mx[rs[mid]][0],mx[mid][0]), mn[mid][0]=min(mn[rs[mid]][0],mn[mid][0]);
mx[mid][1]=max(mx[rs[mid]][1],mx[mid][1]), mn[mid][1]=min(mn[rs[mid]][1],mn[mid][1]);
}
return mid;
}
double ans;
double calc(int u)
{
double x=max(0.0,q[0]-mx[u][0])+max(0.0,mn[u][0]-q[0]),y=max(0.0,q[1]-mx[u][1])+max(0.0,mn[u][1]-q[1]);
return x*x+y*y;
}
void query(int l,int r,int d)
{
if(l>r) return;
int mid=(l+r)>>1;
if(qid!=mid) ans=min(ans,dis(q,p[mid]));
double dl=calc(ls[mid]),dr=calc(rs[mid]);
if(q[d]<p[mid][d])
{
if(dl<ans) query(l,mid-1,d^1);
if(dr<ans) query(mid+1,r,d^1);
}
else
{
if(dr<ans) query(mid+1,r,d^1);
if(dl<ans) query(l,mid-1,d^1);
}
}
```
------------
这个是更好写的代码:
```cpp
bool cmp(Node &a,Node &b)
{
return a[cmpd]<b[cmpd];
}
void build(int l,int r)
{
if(l>r) return;
int mid=(l+r)>>1;
// calc variance
double c[2]= {0.0,0.0};
for(ri d=0; d<=1; ++d)
{
double ave=0.0;
for(ri i=l; i<=r; ++i) ave+=p[i][d];
ave/=(r-l+1);
for(ri i=l; i<=r; ++i) c[d]+=(p[i][d]-ave)*(p[i][d]-ave);
}
cmpd=spl[mid]=(c[0]<c[1]);
nth_element(p+l,p+mid,p+r+1,cmp);
build(l,mid-1), build(mid+1,r);
}
double ans;
void query(int l,int r)
{
if(l>r) return;
int mid=(l+r)>>1,d=spl[mid];
if(qid!=mid) ans=min(ans,dis(q,p[mid]));
double rad=(q[d]-p[mid][d])*(q[d]-p[mid][d]);
if(q[d]<p[mid][d])
{
query(l,mid-1);
if(rad<ans) query(mid+1,r);
}
else
{
query(mid+1,r);
if(rad<ans) query(l,mid-1);
}
}
```
---------
## K-D Tree 区间操作的代码
BZOJ 简单题
```
namespace KD
{
const double alp=0.7;
const int inf=1e9;
struct Node
{
int d[2];
Node() {}
Node(const int _x,const int _y)
{
d[0]=_x,d[1]=_y;
}
int &operator[](const int x)
{
return d[x];
}
const int operator[](const int x) const
{
return d[x];
}
friend bool operator==(const Node &a,const Node &b)
{
return a[0]==b[0]&&a[1]==b[1];
}
} pt[N];
int cmpd;
bool cmp(const Node &a,const Node &b)
{
return a[cmpd]<b[cmpd];
}
il bool cmp2(const pair<Node,int> &a,const pair<Node,int> &b)
{
return cmp(a.fi,b.fi);
}
int tot,ls[N],rs[N],val[N],sum[N],sz[N];
int mx[N][2],mn[N][2];
il void up(int x)
{
sum[x]=sum[ls[x]]+val[x]+sum[rs[x]];
sz[x]=sz[ls[x]]+1+sz[rs[x]];
mx[x][0]=max(max(mx[ls[x]][0],mx[rs[x]][0]),pt[x][0]);
mn[x][0]=min(min(ls[x]?mn[ls[x]][0]:inf,rs[x]?mn[rs[x]][0]:inf),pt[x][0]);
mx[x][1]=max(max(mx[ls[x]][1],mx[rs[x]][1]),pt[x][1]);
mn[x][1]=min(min(ls[x]?mn[ls[x]][1]:inf,rs[x]?mn[rs[x]][1]:inf),pt[x][1]);
}
int tp;
pair<Node,int> st[N];
int cur[N];
void dfs(int x)
{
if(!x) return;
dfs(ls[x]);
st[++tp]=mp(pt[x],val[x]),cur[tp]=x;
dfs(rs[x]);
}
void setup(int &x,int d,int l,int r)
{
if(l>r)
{
x=0;
return;
}
gm;
cmpd=d; nth_element(st+l,st+mid,st+r+1,cmp2);
x=cur[mid]; pt[x]=st[mid].fi,val[x]=st[mid].se;
setup(ls[x],d^1,l,mid-1),setup(rs[x],d^1,mid+1,r);
up(x);
}
void rebuild(int &x,int d)
{
tp=0; dfs(x);
setup(x,d,1,tp);
}
int *ned,nd;
void ins(int &x,int d,const Node &p,int v)
{
if(!x)
{
x=++tot; pt[x]=p,val[x]=v;
up(x);
return;
}
if(p==pt[x])
{
val[x]+=v;
up(x);
return;
}
if(p[d]<pt[x][d]) ins(ls[x],d^1,p,v);
else ins(rs[x],d^1,p,v);
up(x);
if(max(sz[rs[x]],sz[ls[x]])>sz[x]*alp) ned=&x,nd=d;
}
int query(int x,int lx,int ux,int ly,int uy)
{
if(!x||lx>mx[x][0]||ux<mn[x][0]||ly>mx[x][1]||uy<mn[x][1]) return 0;
if(lx<=mn[x][0]&&mx[x][0]<=ux&&ly<=mn[x][1]&&mx[x][1]<=uy) return sum[x];
return query(ls[x],lx,ux,ly,uy)+(lx<=pt[x][0]&&pt[x][0]<=ux&&ly<=pt[x][1]&&pt[x][1]<=uy?val[x]:0)+query(rs[x],lx,ux,ly,uy);
}
int rt;
il void Upd(int x,int y,int k)
{
ned=NULL;
ins(rt,0,Node(x,y),k);
if(ned!=NULL) rebuild(*ned,nd);
}
il int Query(int lx,int ux,int ly,int uy)
{
return query(rt,lx,ux,ly,uy);
}
}
```