P3369 【模板】普通平衡树 题解

· · 题解

题目传送门。

前置知识

学习 fhq 前,你要学会 二叉搜索树。

实现 fhq

存储与创建

在 fhq 中,为了实现平衡,需要用到随机数来作为索引,有序序列构造二叉搜索树会使其退化成一条链,那么打乱顺序后也就不再是一条链了。

所以节点中需要存储左右子节点、内容、索引,有时要访问排名,还要存储子树大小。

创建新节点就只用初始化就行了。

struct node {
    int l,r;
    int dat,key;
    int size;
} treap[100001];
int cnt;
int make(int val) {
    treap[++cnt].dat=val;
    treap[cnt].key=rand();
    treap[cnt].size=1;
    return cnt;
}

分裂

因为二叉搜索树的性质,若父节点小于 val,那么左子树的所有节点都小于 val;若父节点大于 val,那么右子树的所有节点都大于 val。所以对于以上两种情况,分别遍历右子树和左子树就行。

void up(int pos) {//记得更新子树大小
    treap[pos].size=treap[treap[pos].l].size+treap[treap[pos].r].size+1;//记得+1
}
void cut(int val,int pos,int &c1,int &c2) {
    if(!pos)c1=c2=0;
    else {
        if(treap[pos].dat<=val) {
            c1=pos;
            cut(val,treap[pos].r,treap[pos].r,c2);
        } else {
            c2=pos;
            cut(val,treap[pos].l,c1,treap[pos].l);
        }
        up(pos);
    }
}

合并

将根为 x 的搜索树与根为 y 的搜索树合并(这里 x 中的元素最大值必需小于等于 y 中的元素最小值),我们的索引就起作用了。判断 x 的索引与 y 的索引的大小关系(可以是各种关系,只要满足随机就行)然后决定是将 x 接到 y 左下方还是 y 接到 x 右下方即可。

int join(int r1,int r2) {
    if(!r1||!r2)return r1+r2;
    else {
        if(treap[r1].key>treap[r2].key) {
            treap[r1].r=join(treap[r1].r,r2);
            up(r1);
            return r1;
        } else {
            treap[r2].l=join(r1,treap[r2].l);
            up(r2);
            return r2;
        }
    }
}

插入元素

分为两步:

  1. 将原树按 val 切开为 x,y

  2. 合并 x,val,y

    void take(int val) {
    cut(val,root,x,y);
    root=join(join(x,make(val)),y);
    }

    删除元素

    分为四步:

  3. 将原树按 val 切开为 x,y

  4. xval-1 切开为 x,z

  5. 合并 z 的左子树和右子树。

  6. 合并 x,y,z

    void del(int val) {
    cut(val,root,x,y);
    cut(val-1,x,x,z);
    z=join(treap[z].l,treap[z].r);
    root=join(join(x,z),y);
    }

    查看排名

    分为三步:

  7. 将原树按 val-1 切开为 x,y

  8. 此时 xsize 即为 val 的排名。

  9. 合并 x,y

    int rank(int val) {
    cut(val-1,root,x,y);
    int pos=treap[pos].size+1;
    root=join(x,y);
    return pos;
    }

    根据排名查看数据

    由于本质上就是二叉搜索树,所以使用和权值线段树一样的方法即可。

    int num(int val) {
    int pos=root;
    while(pos) {
        if(treap[treap[pos].l].size+1==val)return treap[pos].dat;
        else if(treap[treap[pos].l].size+1>val)pos=treap[pos].l;
        else {
            val-=treap[treap[pos].l].size+1;
            pos=treap[pos].r;
        }
    }
    return 0;
    }

    前驱与后缀

    前驱为按 val-1 切开为 x,yx 的最右端点,后缀为按 val 切开时 y 的最左端点。

    int pre(int val)
    {
    cut(val-1,root,x,y);
    int pos=x;
    while(treap[pos].r)pos=treap[pos].r;
    root=join(x,y);
    return treap[pos].dat;
    }
    int last(int val)
    {
    cut(val,root,x,y);
    int pos=y;
    while(treap[pos].l)pos=treap[pos].l;
    root=join(x,y);
    return treap[pos].dat;
    }

    完整代码

    #include<bits/stdc++.h>
    #include<time.h>
    using namespace std;
    struct node {
    int l,r;
    int dat,key;
    int size;
    } treap[100001];
    int cnt,root;
    int x,y,z;
    int n,opt,a;
    int make(int val) {
    treap[++cnt].dat=val;
    treap[cnt].key=rand();
    treap[cnt].size=1;
    return cnt;
    }
    void up(int pos) {
    treap[pos].size=treap[treap[pos].l].size+treap[treap[pos].r].size+1;
    }
    void cut(int val,int pos,int &c1,int &c2) {
    if(!pos)c1=c2=0;
    else {
        if(treap[pos].dat<=val) {
            c1=pos;
            cut(val,treap[pos].r,treap[pos].r,c2);
        } else {
            c2=pos;
            cut(val,treap[pos].l,c1,treap[pos].l);
        }
        up(pos);
    }
    }
    int join(int r1,int r2) {
    if(!r1||!r2)return r1+r2;
    else {
        if(treap[r1].key>treap[r2].key) {
            treap[r1].r=join(treap[r1].r,r2);
            up(r1);
            return r1;
        } else {
            treap[r2].l=join(r1,treap[r2].l);
            up(r2);
            return r2;
        }
    }
    }
    void take(int val) {
    cut(val,root,x,y);
    root=join(join(x,make(val)),y);
    }
    void del(int val) {
    cut(val,root,x,y);
    cut(val-1,x,x,z);
    z=join(treap[z].l,treap[z].r);
    root=join(join(x,z),y);
    }
    int wrank(int val) {
    cut(val-1,root,x,y);
    int pos=treap[x].size+1;
    root=join(x,y);
    return pos;
    }
    int num(int val) {
    int pos=root;
    while(pos) {
        if(treap[treap[pos].l].size+1==val)return treap[pos].dat;
        else if(treap[treap[pos].l].size+1>val)pos=treap[pos].l;
        else {
            val-=treap[treap[pos].l].size+1;
            pos=treap[pos].r;
        }
    }
    return 0;
    }
    int last(int val)
    {
    cut(val,root,x,y);
    int pos=y;
    while(treap[pos].l)pos=treap[pos].l;
    root=join(x,y);
    return treap[pos].dat;
    }
    int pre(int val)
    {
    cut(val-1,root,x,y);
    int pos=x;
    while(treap[pos].r)pos=treap[pos].r;
    root=join(x,y);
    return treap[pos].dat;
    }
    int main() {
    srand(time(0));
    cin>>n;
    for(int i=1;i<=n;i++)
    {
        cin>>opt>>a;
        if(opt==1)take(a);
        if(opt==2)del(a);
        if(opt==3)cout<<wrank(a)<<endl;
        if(opt==4)cout<<num(a)<<endl;
        if(opt==5)cout<<pre(a)<<endl;
        if(opt==6)cout<<last(a)<<endl;
    }
    }