平衡树Treap

枫林晚

2018-04-30 23:21:37

Solution

普通平衡树 Treap ### 题目描述 您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作: 1.插入 x 数 2.删除 x 数(若有多个相同的数,因只删除一个) 3.查询 x 数的排名(排名定义为比当前数小的数的个数 +1 。若有多个相同的数,因输出最小的排名) 4.查询排名为 x 的数 5.求 x 的前驱(前驱定义为小于 x ,且最大的数) 6.求 x 的后继(后继定义为大于 x ,且最小的数) ### 输入输出格式 输入格式: 第一行为 n ,表示操作的个数,下面 n 行每行有两个数 opt 和 x , opt 表示操作的序号( 1≤opt≤6 ) 输出格式: 对于操作 3,4,5,6 每行输出一个数,表示对应答案 ## 分析: 平衡树板子题。 注意的是: 1.增加相同的数累计次数。(pushup等)(1——>t[o].sum)(WA无数) 2.注意哨兵不要被赋值(pushup) 3.查询排名,是最小的。 4.查询排名为x的数,就是第k大,注意是k-t[o].sum-t[t[o].ch[0]].size 1.初始化: 其中sum记录该节点val值有几个 ```cpp const int N=100000+10; const int inf=0x3f3f3f3f; struct node{ int ch[2]; int val,size,prio; int sum; }; void init(){ root=0;//认为root开始为0 } ``` 2.发放节点/回收节点 ```cpp int poolcur; int delpool[N],delcur; int newnode(){ int r=delcur?delpool[delcur--]:++poolcur; memset(t+r,0,sizeof(node)); t[r].size=1; t[r].prio=rand(); t[r].sum=1; return r; } void delnode(int o){ delpool[++delcur]=o; } ``` 3.pushup 注意: ①.t[t[o].ch[1]].size+t[t[o].ch[0]].size+**t[o].sum;** ②.if(!o) return; 保护哨兵(这里其实不需要了) ```cpp void pushup(int o){ if(!o) return; t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum; } ``` 4.rotate 注意o=u位置放最后,注意pushup()位置。 ```cpp void rotate(int &o,int d){ if(!o) return; int u=t[o].ch[d]; t[o].ch[d]=t[u].ch[d^1]; t[u].ch[d^1]=o; t[u].size=t[o].size; pushup(o); o=u; } ``` 5.insert 注意: ①值相等时,++sum与size(调半天) ②注意pushup ③根据优先级rotate ④发现当o为0时,会建边,所以开始root必须为0,而哨兵不能有问题。并且发放节点时,++poolcur而不是poolcur++,保证取到整数节点号。(调半天) ```cpp void insert(int &o,int v){ if(!o){ o=newnode(); t[o].val=v; return; } if(t[o].val==v){ t[o].sum++; t[o].size++; return; } int d=t[o].val<v; insert(t[o].ch[d],v); pushup(o); if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d); } ``` 6.remove 注意: ①先判断是否找到了删除点。保证找到了,再删除(WA一次) ②先删除一下sum,若不为0,跳过这一步。否则继续操作。注意每次判断sum是否可以再删,否则可能会在后面的递归中将sum删成负数。 ③无论如何必须要记得pushup(WA一次) ```cpp void remove(int &o,int v){ if(!o) return; if(t[o].val==v){ if(t[o].sum>0) t[o].sum--; if(t[o].sum<=0) { int u=o; if(!t[o].ch[0]){ o=t[o].ch[1]; delnode(u); } else if(!t[o].ch[1]){ o=t[o].ch[0]; delnode(u); } else{ int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio; rotate(o , d^1); remove(t[o].ch[d],v); } } } else{ int d=t[o].val<v; remove(t[o].ch[d] , v); } pushup(o); } ``` 7.对于操作3找v的排名 注意: ①找到最小的位置,所以1+t[t[o].ch[0]].size 注意左子树都比它小(调半天) ②因为可能不存在v 所以(!o)return 0 ③往下找 +t[o].sum而不是1 ```cpp int rank(int o, int v)//找到v的排名 { if(!o) return 0; if(t[o].val>v) return rank(t[o].ch[0], v); else if(t[o].val==v) return 1+t[t[o].ch[0]].size; else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v); } ``` 8.对于操作4找到第k大 ①想往下找必须剩下的d>t[o].sum(WA一次) ```cpp int find(int o,int k){// 找第K小的数 if(o==0||k==0) return 0; int d=k-t[t[o].ch[0]].size; if(d<= 0) return find(t[o].ch[0],k); else if(d>=1&&d<=t[o].sum) return t[o].val; else return find(t[o].ch[1],d-t[o].sum); } ``` 9.对于5/6找前去后继 ①想清楚min/max以及inf/-inf就好 ```cpp int find_front(int o,int v)//找到小于v的最大值 { if(!o) return -inf; int d= t[o].val>=v; if(!d) return max(t[o].val,find_front(t[o].ch[1],v)); else return find_front(t[o].ch[0],v); } ``` ***************** 代码纯享: ```cpp #include<bits/stdc++.h> using namespace std; const int N=100000+10; const int inf=0x3f3f3f3f; struct node{ int ch[2]; int val,size,prio; int sum; }; node t[N]; int n,root; int poolcur; int delpool[N],delcur; void init(){ root=0; } int newnode(){ int r=delcur?delpool[delcur--]:++poolcur; memset(t+r,0,sizeof(node)); t[r].size=1; t[r].prio=rand(); t[r].sum=1; return r; } void delnode(int o){ delpool[++delcur]=o; } void pushup(int o){ if(!o) return; t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum; } void rotate(int &o,int d){ if(!o) return; int u=t[o].ch[d]; t[o].ch[d]=t[u].ch[d^1]; t[u].ch[d^1]=o; t[u].size=t[o].size; pushup(o); o=u; } void insert(int &o,int v){ if(!o){ o=newnode(); t[o].val=v; return; } if(t[o].val==v){ t[o].sum++; t[o].size++; return; } int d=t[o].val<v; insert(t[o].ch[d],v); pushup(o); if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d); } void remove(int &o,int v){ if(!o) return; if(t[o].val==v){ if(t[o].sum>0) t[o].sum--; if(t[o].sum<=0) { int u=o; if(!t[o].ch[0]){ o=t[o].ch[1]; delnode(u); } else if(!t[o].ch[1]){ o=t[o].ch[0]; delnode(u); } else{ int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio; rotate(o , d^1); remove(t[o].ch[d],v); } } } else{ int d=t[o].val<v; remove(t[o].ch[d] , v); } pushup(o); } int rank(int o, int v)//找到v的排名 { if(!o) return 0; if(t[o].val>v) return rank(t[o].ch[0], v); else if(t[o].val==v) return 1+t[t[o].ch[0]].size; else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v); } int find(int o,int k){// 找第K小的数 if(o==0||k==0) return 0; int d=k-t[t[o].ch[0]].size; if(d<= 0) return find(t[o].ch[0],k); else if(d>=1&&d<=t[o].sum) return t[o].val; else return find(t[o].ch[1],d-t[o].sum); } int find_back(int o,int v)//找到大于v的最小值 { if(!o) return inf; int d= t[o].val<=v; if(!d) return min(t[o].val,find_back(t[o].ch[0],v)); else return find_back(t[o].ch[1],v); } int find_front(int o,int v)//找到小于v的最大值 { if(!o) return -inf; int d= t[o].val>=v; if(!d) return max(t[o].val,find_front(t[o].ch[1],v)); else return find_front(t[o].ch[0],v); } int ans[N]; int cnt; int main() { scanf("%d",&n); init(); int p,x; int has=0; srand((unsigned)time(NULL)); for(int i=1;i<=n;i++) { scanf("%d%d",&p,&x); if(p==1) insert(root,x),has++; if(p==2) {remove(root,x);has--;} if(p==3) ans[++cnt]=rank(root,x); if(p==4) ans[++cnt]=find(root,x); if(p==5) ans[++cnt]=find_front(root,x); if(p==6) ans[++cnt]=find_back(root,x); } for(int i=1;i<=cnt;i++) printf("%d\n",ans[i]); return 0; } ``` 注意细节,打多了就好