平衡树Treap
枫林晚
2018-04-30 23:21:37
普通平衡树 Treap
### 题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1.插入 x 数
2.删除 x 数(若有多个相同的数,因只删除一个)
3.查询 x 数的排名(排名定义为比当前数小的数的个数 +1 。若有多个相同的数,因输出最小的排名)
4.查询排名为 x 的数
5.求 x 的前驱(前驱定义为小于 x ,且最大的数)
6.求 x 的后继(后继定义为大于 x ,且最小的数)
### 输入输出格式
输入格式:
第一行为 n ,表示操作的个数,下面 n 行每行有两个数 opt 和 x , opt 表示操作的序号( 1≤opt≤6 )
输出格式:
对于操作 3,4,5,6 每行输出一个数,表示对应答案
## 分析:
平衡树板子题。
注意的是:
1.增加相同的数累计次数。(pushup等)(1——>t[o].sum)(WA无数)
2.注意哨兵不要被赋值(pushup)
3.查询排名,是最小的。
4.查询排名为x的数,就是第k大,注意是k-t[o].sum-t[t[o].ch[0]].size
1.初始化:
其中sum记录该节点val值有几个
```cpp
const int N=100000+10;
const int inf=0x3f3f3f3f;
struct node{
int ch[2];
int val,size,prio;
int sum;
};
void init(){
root=0;//认为root开始为0
}
```
2.发放节点/回收节点
```cpp
int poolcur;
int delpool[N],delcur;
int newnode(){
int r=delcur?delpool[delcur--]:++poolcur;
memset(t+r,0,sizeof(node));
t[r].size=1;
t[r].prio=rand();
t[r].sum=1;
return r;
}
void delnode(int o){
delpool[++delcur]=o;
}
```
3.pushup
注意:
①.t[t[o].ch[1]].size+t[t[o].ch[0]].size+**t[o].sum;**
②.if(!o) return; 保护哨兵(这里其实不需要了)
```cpp
void pushup(int o){
if(!o) return;
t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum;
}
```
4.rotate
注意o=u位置放最后,注意pushup()位置。
```cpp
void rotate(int &o,int d){
if(!o) return;
int u=t[o].ch[d];
t[o].ch[d]=t[u].ch[d^1];
t[u].ch[d^1]=o;
t[u].size=t[o].size;
pushup(o);
o=u;
}
```
5.insert
注意:
①值相等时,++sum与size(调半天)
②注意pushup
③根据优先级rotate
④发现当o为0时,会建边,所以开始root必须为0,而哨兵不能有问题。并且发放节点时,++poolcur而不是poolcur++,保证取到整数节点号。(调半天)
```cpp
void insert(int &o,int v){
if(!o){
o=newnode();
t[o].val=v;
return;
}
if(t[o].val==v){
t[o].sum++;
t[o].size++;
return;
}
int d=t[o].val<v;
insert(t[o].ch[d],v);
pushup(o);
if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d);
}
```
6.remove
注意:
①先判断是否找到了删除点。保证找到了,再删除(WA一次)
②先删除一下sum,若不为0,跳过这一步。否则继续操作。注意每次判断sum是否可以再删,否则可能会在后面的递归中将sum删成负数。
③无论如何必须要记得pushup(WA一次)
```cpp
void remove(int &o,int v){
if(!o) return;
if(t[o].val==v){
if(t[o].sum>0) t[o].sum--;
if(t[o].sum<=0)
{
int u=o;
if(!t[o].ch[0]){
o=t[o].ch[1];
delnode(u);
}
else if(!t[o].ch[1]){
o=t[o].ch[0];
delnode(u);
}
else{
int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio;
rotate(o , d^1);
remove(t[o].ch[d],v);
}
}
}
else{
int d=t[o].val<v;
remove(t[o].ch[d] , v);
}
pushup(o);
}
```
7.对于操作3找v的排名
注意:
①找到最小的位置,所以1+t[t[o].ch[0]].size
注意左子树都比它小(调半天)
②因为可能不存在v 所以(!o)return 0
③往下找 +t[o].sum而不是1
```cpp
int rank(int o, int v)//找到v的排名
{
if(!o) return 0;
if(t[o].val>v) return rank(t[o].ch[0], v);
else if(t[o].val==v) return 1+t[t[o].ch[0]].size;
else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v);
}
```
8.对于操作4找到第k大
①想往下找必须剩下的d>t[o].sum(WA一次)
```cpp
int find(int o,int k){// 找第K小的数
if(o==0||k==0) return 0;
int d=k-t[t[o].ch[0]].size;
if(d<= 0) return find(t[o].ch[0],k);
else if(d>=1&&d<=t[o].sum) return t[o].val;
else return find(t[o].ch[1],d-t[o].sum);
}
```
9.对于5/6找前去后继
①想清楚min/max以及inf/-inf就好
```cpp
int find_front(int o,int v)//找到小于v的最大值
{
if(!o) return -inf;
int d= t[o].val>=v;
if(!d) return max(t[o].val,find_front(t[o].ch[1],v));
else return find_front(t[o].ch[0],v);
}
```
*****************
代码纯享:
```cpp
#include<bits/stdc++.h>
using namespace std;
const int N=100000+10;
const int inf=0x3f3f3f3f;
struct node{
int ch[2];
int val,size,prio;
int sum;
};
node t[N];
int n,root;
int poolcur;
int delpool[N],delcur;
void init(){
root=0;
}
int newnode(){
int r=delcur?delpool[delcur--]:++poolcur;
memset(t+r,0,sizeof(node));
t[r].size=1;
t[r].prio=rand();
t[r].sum=1;
return r;
}
void delnode(int o){
delpool[++delcur]=o;
}
void pushup(int o){
if(!o) return;
t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum;
}
void rotate(int &o,int d){
if(!o) return;
int u=t[o].ch[d];
t[o].ch[d]=t[u].ch[d^1];
t[u].ch[d^1]=o;
t[u].size=t[o].size;
pushup(o);
o=u;
}
void insert(int &o,int v){
if(!o){
o=newnode();
t[o].val=v;
return;
}
if(t[o].val==v){
t[o].sum++;
t[o].size++;
return;
}
int d=t[o].val<v;
insert(t[o].ch[d],v);
pushup(o);
if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d);
}
void remove(int &o,int v){
if(!o) return;
if(t[o].val==v){
if(t[o].sum>0) t[o].sum--;
if(t[o].sum<=0)
{
int u=o;
if(!t[o].ch[0]){
o=t[o].ch[1];
delnode(u);
}
else if(!t[o].ch[1]){
o=t[o].ch[0];
delnode(u);
}
else{
int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio;
rotate(o , d^1);
remove(t[o].ch[d],v);
}
}
}
else{
int d=t[o].val<v;
remove(t[o].ch[d] , v);
}
pushup(o);
}
int rank(int o, int v)//找到v的排名
{
if(!o) return 0;
if(t[o].val>v) return rank(t[o].ch[0], v);
else if(t[o].val==v) return 1+t[t[o].ch[0]].size;
else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v);
}
int find(int o,int k){// 找第K小的数
if(o==0||k==0) return 0;
int d=k-t[t[o].ch[0]].size;
if(d<= 0) return find(t[o].ch[0],k);
else if(d>=1&&d<=t[o].sum) return t[o].val;
else return find(t[o].ch[1],d-t[o].sum);
}
int find_back(int o,int v)//找到大于v的最小值
{
if(!o) return inf;
int d= t[o].val<=v;
if(!d) return min(t[o].val,find_back(t[o].ch[0],v));
else return find_back(t[o].ch[1],v);
}
int find_front(int o,int v)//找到小于v的最大值
{
if(!o) return -inf;
int d= t[o].val>=v;
if(!d) return max(t[o].val,find_front(t[o].ch[1],v));
else return find_front(t[o].ch[0],v);
}
int ans[N];
int cnt;
int main()
{
scanf("%d",&n);
init();
int p,x;
int has=0;
srand((unsigned)time(NULL));
for(int i=1;i<=n;i++)
{
scanf("%d%d",&p,&x);
if(p==1) insert(root,x),has++;
if(p==2) {remove(root,x);has--;}
if(p==3) ans[++cnt]=rank(root,x);
if(p==4) ans[++cnt]=find(root,x);
if(p==5) ans[++cnt]=find_front(root,x);
if(p==6) ans[++cnt]=find_back(root,x);
}
for(int i=1;i<=cnt;i++)
printf("%d\n",ans[i]);
return 0;
}
```
注意细节,打多了就好