替罪羊树
替罪羊树
替罪羊树是平衡树中较为简单的一种数据结构。
首先来看模板
洛谷P3369普通平衡树
相信大家都学过二叉搜索树,但是显然凉心出题人会卡BST不过BST在随机不畸形数据里速度确实吊打一票平衡树
BST由于很容易退化成链,所以我们必须想个办法不让他退化成链,必须要用一种方法让他看上去长得尽量平衡而又不破坏性质。
因此人们就想出了各种各样奇奇怪怪的方法来维护,所以才出现了这么多种类的平衡树。
平衡树按照平衡的方法一般分为两大类:旋转平衡树和重量平衡树
旋转平衡树就是基于旋转节点换根来维护平衡,splay,spaly,treap,红黑树等等都是这种。
而重量平衡树主要依靠判断左右儿子的数量关系来判断平衡,替罪羊就属于这一类。
旋转平衡树的一大优势就在于支持一些子树操作,因为它一次只旋转部分节点,没有破坏整棵树的性质和结构。而且LCT要用splay虽然我LCT和splay都不会
但是旋转平衡树也有一个缺点,就是面对一个节点维护巨量信息的时候,旋转的复杂度就会很高,如果一个节点维护子树里所有节点的信息,那么旋转的复杂度就会直接爆炸。
替罪羊树则不会出现这种问题,而且替罪羊树相对通俗易懂,比较好写。
其实替罪羊树的自平衡方法非常简单,如果发现一棵子树不平衡,那就把它暴力拍扁成序列,然后再重构成绝对平衡的完全二叉树接回去(也不一定是绝对平衡,可能因为节点数量导致最后一层不满)
那么如何判断是否不平衡呢?
旋转平衡树是不需要判断,每次都旋转来追求均摊复杂度。
但是替罪羊如果不需要判断,每次都重构的话会直接卡成n方,所以显然要找到一个方法。
我们设定一个因子a其实是阿尔法只是懒得打代码
这个a要设定在0.5到1之间
如果一个子树的大小乘以a小于左子树大小或右子树,那么这个子树就不平衡。
其实这个也很好理解,如果左儿子占到了整体了百分之七八十,那肯定不平衡,直接拍扁了好了。
至于这个a到底是多少,一般取0.75就行,取0.7或0.8之间也无所谓。
但是不能太偏。如果太小,那么稍微一偏就重构,重构的次数太多,导致变慢。如果太大,都畸形了还是不重构,那么就影响查询的效率。所以取个中一般是比较好的选择。
那么接下来就是代码。
先上变量。
struct node{
int l,r,x,tot,size,fa,all,ava;
bool del;
}t[1000100];
l和r这里和线段树不一样,不是左右区间,而是左右儿子是谁(其实就是变量模拟指针。)
x表示这个节点代表的权值,tot表示这个权值的数的个数。
size表示这棵子树的大小,fa表示这个节点的父亲,ava表示这棵子树内没被删除的点的个数(available)
all表示这棵子树所有节点的tot的总和。
del表示这个节点有没有被删除。1删除,0还在。
这里写过动态开点线段树的大伙应该知道,手动分配内存是一个很不错的选择。但是那里只是随用随加。这里还不够,我们要做到废物利用,也就是说把删除的点的地址重新分配给新节点,那么这个就需要手写内存池分配函数了。
inline int kk(){
if(ts>0)return ck[ts--];
else return ++len;
}
ck数组存储的删除掉的点的地址,注意,这里的“删除”不是指执行删除操作后不存在的点,而是拍扁成序列时一开始那些不平衡的树。如果还有剩下的就优先分配剩下的,如果没有就新开。
那么这是前前置的基本函数,接下来就开始上操作了。
首先是第一个操作:加点。
大家应该都知道BST中加点的操作,这里也差不多,但是要考虑几个细节,这些细节待会再提。首先是先寻找到表示这个权值的点,然后让这个点的tot+1。
那么寻找函数如下:
int find(int x,int nowp){
if(x<t[nowp].x&&t[nowp].l)return find(x,t[nowp].l);
if(x>t[nowp].x&&t[nowp].r)return find(x,t[nowp].r);
return nowp;
}
寻找表示x这个值的点,当前位置在nowp。由于BST的性质,如果要找的x比当前点小且存在左儿子,那么目标节点就一定在nowp的左儿子里,反之就在右儿子。
但是这里有个地方需要注意,就是它不一定正好返回那个精确的位置,如果表示x的节点不存在,那么就会返回当前跑到的点的位置。所以添加的时候需要判断是否正好在那个表示x的节点,如果在就加上去,如果不在,就说明表示x的节点不存在,需要新建一个节点来表示x。
那么接下来,就是新建节点的函数了。
void build(int x,int nowp,int fa){
t[nowp].l=t[nowp].r=0;
t[nowp].fa=fa,t[nowp].x=x;
t[nowp].del=0;
t[nowp].tot=t[nowp].all=t[nowp].ava=1;
}
虽然一个节点的信息很多,但是真正带进去的参数只有三个。
左右儿子一开始肯定不存在全部设成0。注意这里不能因为定义的是全局变量初始化为0就把这个设成0的这一句去掉,因为有可能新建的节点的地址是之前被删掉的点,而之前被删掉的点有可能有左右儿子,这样就会出现一些奇怪的问题。同理,所有的信息都要重新覆盖一遍。
那么,找到了节点以后,不仅仅是tot+1这么简单,你的all也要改,ava也要改,而且你的父亲,你的爷爷,你的十八辈祖宗由于他们的子树包括你,他们的子树信息也要改,所以简单的一句话是不行的,又要新写一个函数。
void update(int nowp,int a,int b,int c){
if(!nowp)return;
t[nowp].ava+=a;
t[nowp].size+=b;
t[nowp].all+=c;
update(t[nowp].fa,a,b,c);
}
这里也要判断及时跳出去,别陷入0的死循环了。跟子树有关的只有三个信息,所以只用改3处。
那么这些前置函数看完了,就可以来看add函数了。
void add(int x){
if(root==0){
build(x,root=kk(),0);
return;
}
int p=find(x,root);
if(x==t[p].x){
t[p].tot++;
if(t[p].del)t[p].del=0,update(p,1,0,1);
else update(p,0,0,1);
}
else if(x<t[p].x)build(x,t[p].l=kk(),p),update(p,1,1,1);
else build(x,t[p].r=kk(),p),update(p,1,1,1);
find_rebuild(root,x);
}
这里还要特判一下如果是第一次,还要新建根,然后建完根以后实际上就完成了添加的操作(毕竟他没有祖宗,所以不需要update)
然后就是调用find函数找点了。上面也有提到,这里需要判断是否正好在那个表示x的节点。
如果在那个节点,那么还要判断一下这个点是否被删除了。如果被删除了,那么就再改成没删除,然后修改available和all。因为size是包括打上了删除标记的节点,所以size不需要改。如果一开始没被删除,那么就只有all要改。
如果不在那个节点,就表示目前还没有表示x的节点,那么就新建一个。这里新建的时候就要调用内存池了,而且available,all,size都要修改。
另外因为加点可能改变树的亚子,所以要判断一下是否需要重构,这是后话后面再说。
这就是加入。接下来是删除。
删除比较简单,因为本题貌似只会删除存在的点...所以调用find以后也不用判断,直接删就好了。
代码如下:
void del(int x){
int p=find(x,root);
t[p].tot--;
if(!t[p].tot)t[p].del=1,update(p,-1,0,-1);
else update(p,0,0,-1);
find_rebuild(root,x);
}
减完以后,如果这点的tot已经归零,这点就没了,所以我们就打上标记,顺便修改一下available和all。
如果还有,那么只需要修改all。同上,也需要判断是否需要修改子树。
那么接下来就是find_rebuild这个函数了:
void find_rebuild(int nowp,int x){
if((double)max(t[t[nowp].l].size,t[t[nowp].r].size)>(double)t[nowp].size*alpha||(double)t[nowp].size-double(t[nowp].ava)>(double)t[nowp].size*0.4){
re_build(nowp);
return;
}
if(t[nowp].x!=x)find_rebuild(x<t[nowp].x?t[nowp].l:t[nowp].r,x);
}
这里除了上文的判断左右儿子子树大小外,又加了一个剪枝:如果被删掉的点占这个子树的40%还多,那么就可以直接重构,这样可以去掉很多被删掉的点,加快速度释放空间。
另外这里的两个参数,第一个是当前点,第二个是目标点。调用这个函数时必须要从根往下走到目标点,不然会出事。
比如你从目标点往上走到一个点,发现不平衡,重构一遍。又往上走,又发现不平衡,再来一遍。这样可能要反反复复重构好几次,把最坏复杂度又乘上了一个log。但是如果你从上往下走,发现不平衡就一次性全部重构,就不需要往下走了。
那么重构有两步,第一步是dfs中序遍历来拍扁序列,第二步是把序列重构成完全二叉树。那么这两个函数如下:
void dfs_rebuild(int nowp){
if(nowp==0)return;
dfs_rebuild(t[nowp].l);
if(!t[nowp].del){
sl[++tt].x=t[nowp].x;
sl[tt].tot=t[nowp].tot;
}
ck[++ts]=nowp;
dfs_rebuild(t[nowp].r);
}
这个没什么好说的,就是一个普通的中序遍历的函数,这里的序列需要记录两个值,一个是表示的元素x,一个是个数tot。注意一下判断这个点是否被删掉,及时跳出0的死循环,然后拍扁以后把这个地址加进垃圾桶就行。
然后是重构函数re_add:
int re_add(int l,int r,int fa){
if(l>r)return 0;
int mid=l+r>>1,id=kk();
t[id].fa=fa;
t[id].tot=sl[mid].tot;
t[id].x=sl[mid].x;
t[id].l=re_add(l,mid-1,id);
t[id].r=re_add(mid+1,r,id);
t[id].all=t[t[id].l].all+t[t[id].r].all+sl[mid].tot;
t[id].del=0;
t[id].size=t[id].ava=r-l+1;
return id;
}
三个参数,当前要重构的左右区间,以及这个区间重构出来的子树的父亲。
这里也很简单,但是要注意新建点存储信息的时候这个顺序不能乱,不然会出一些错误。
还有统计all的时候别光算左右儿子的,还得算自己的tot。
最后返回整棵子树的根节点。
re_build函数如下:
void re_build(int nowp){
tt=0;
dfs_rebuild(nowp);
if(nowp==root)root=re_add(1,tt,0);
else {
update(t[nowp].fa,0,-t[nowp].size+t[nowp].ava,0);
if(t[t[nowp].fa].l==nowp)t[t[nowp].fa].l=re_add(1,tt,t[nowp].fa);
else t[t[nowp].fa].r=re_add(1,tt,t[nowp].fa);
}
}
nowp是要重构的子树父节点。
tt是序列的指针,但是这里要注意,这里往回接的时候不能接错了,如果是把整棵替罪羊树重构,那么就更新根,如果不是,就需要先update一下,因为重构完了以后没有保留del=1的点,所以要把size中不可用的那一部分减掉。
现在已经知道重构的不是根了,就要判断一下是nowp是父亲的左儿子还是右儿子,别接错了。
这样整个重构的四个函数就讲完了,接下来是add和del后的四个操作。因为它们没有改变树的结构,所以不需要判断是否不平衡。另外本题不会查询到不存在的数,所以也不需要判断节点是否存在,跑到了就统计答案即可。
int findxpm(int x){
int nowp=root;
int ans=1;
while(nowp){
if(x<=t[nowp].x)nowp=t[nowp].l;
else ans+=t[t[nowp].l].all+t[nowp].tot,nowp=t[nowp].r;
}
ans+=t[t[nowp].l].all;
return ans;
}
这个是找x的排名。这个就和BST完全一样了。注意最后输出的时候answer要+1,因为排名没有第0名,只有第一名。
int findpmx(int x){
int nowp=root;
while(nowp){
if(x<=t[t[nowp].l].all)nowp=t[nowp].l;
else {
x-=t[t[nowp].l].all;
if(x<=t[nowp].tot){
return t[nowp].x;
}
x-=t[nowp].tot;
nowp=t[nowp].r;
}
}
}
查找排名为x的数。这也和BST一样。
查找前驱和后继就非常简单了,找前驱就是找这个数的排名前一位的数,找后继就是找这个数+1的数的排名
完整代码如下:其实也不是很长,不如树剖和一些图论套dp题
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cstdlib>
using std::max;
struct node{
int l,r,x,tot,size,fa,all,ava;
bool del;
}t[1000100];
struct sll{
int x,tot;
}sl[1000100];
int tt;
int len,n,root;
int ck[1000100],ts;
double alpha=0.75;
void build(int x,int nowp,int fa){
t[nowp].l=t[nowp].r=0;
t[nowp].fa=fa,t[nowp].x=x;
t[nowp].del=0;
t[nowp].tot=t[nowp].all=t[nowp].ava=1;
}
inline int kk(){
if(ts>0)return ck[ts--];
else return ++len;
}
void update(int nowp,int a,int b,int c){
if(!nowp)return;
t[nowp].ava+=a;
t[nowp].size+=b;
t[nowp].all+=c;
update(t[nowp].fa,a,b,c);
}
int find(int x,int nowp){
if(x<t[nowp].x&&t[nowp].l)return find(x,t[nowp].l);
if(x>t[nowp].x&&t[nowp].r)return find(x,t[nowp].r);
return nowp;
}
void dfs_rebuild(int nowp){
if(nowp==0)return;
dfs_rebuild(t[nowp].l);
if(!t[nowp].del){
sl[++tt].x=t[nowp].x;
sl[tt].tot=t[nowp].tot;
}
ck[++ts]=nowp;
dfs_rebuild(t[nowp].r);
}
int re_add(int l,int r,int fa){
if(l>r)return 0;
int mid=l+r>>1,id=kk();
t[id].fa=fa;
t[id].tot=sl[mid].tot;
t[id].x=sl[mid].x;
t[id].l=re_add(l,mid-1,id);
t[id].r=re_add(mid+1,r,id);
t[id].all=t[t[id].l].all+t[t[id].r].all+sl[mid].tot;
t[id].del=0;
t[id].size=t[id].ava=r-l+1;
return id;
}
void re_build(int nowp){
tt=0;
dfs_rebuild(nowp);
if(nowp==root)root=re_add(1,tt,0);
else {
update(t[nowp].fa,0,-t[nowp].size+t[nowp].ava,0);
if(t[t[nowp].fa].l==nowp)t[t[nowp].fa].l=re_add(1,tt,t[nowp].fa);
else t[t[nowp].fa].r=re_add(1,tt,t[nowp].fa);
}
}
void find_rebuild(int nowp,int x){
if((double)max(t[t[nowp].l].size,t[t[nowp].r].size)>(double)t[nowp].size*alpha||
(double)t[nowp].size-double(t[nowp].ava)>(double)t[nowp].size*0.4){
re_build(nowp);
return;
}
if(t[nowp].x!=x)find_rebuild(x<t[nowp].x?t[nowp].l:t[nowp].r,x);
}
void add(int x){
if(root==0){
build(x,root=kk(),0);
return;
}
int p=find(x,root);
if(x==t[p].x){
t[p].tot++;
if(t[p].del)t[p].del=0,update(p,1,0,1);
else update(p,0,0,1);
}
else if(x<t[p].x)build(x,t[p].l=kk(),p),update(p,1,1,1);
else build(x,t[p].r=kk(),p),update(p,1,1,1);
find_rebuild(root,x);
}
void del(int x){
int p=find(x,root);
t[p].tot--;
if(!t[p].tot)t[p].del=1,update(p,-1,0,-1);
else update(p,0,0,-1);
find_rebuild(root,x);
}
int findrank(int x){
int nowp=root;
int ans=1;
while(nowp){
if(x<=t[nowp].x)nowp=t[nowp].l;
else ans+=t[t[nowp].l].all+t[nowp].tot,nowp=t[nowp].r;
}
ans+=t[t[nowp].l].all;
return ans;
}
int findkth(int x){
int nowp=root;
while(nowp){
if(x<=t[t[nowp].l].all)nowp=t[nowp].l;
else {
x-=t[t[nowp].l].all;
if(x<=t[nowp].tot){
return t[nowp].x;
}
x-=t[nowp].tot;
nowp=t[nowp].r;
}
}
}
int main(){
scanf("%d",&n);
while(n--){
int id,x;
scanf("%d%d",&id,&x);
if(id==1)add(x);
if(id==2)del(x);
if(id==3)printf("%d\n",findrank(x));
if(id==4)printf("%d\n",findkth(x));
if(id==5)printf("%d\n",findkth(findrank(x)-1));
if(id==6)printf("%d\n",findkth(findrank(x+1)));
}
return 0;
}