【模板】平衡树

· · 算法·理论

二叉搜索树

#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
struct nod
{
    int val,cnt,siz,lc,rc;
    //val表示当前点的权值
    //cnt表示当前权值重复出现的次数(重数)
    //siz表示子树大小(子树中cnt之和)
    //lc表示左子节点的编号,不存在则用0表示
    //rc表示右子节点的编号,不存在则用0表示
};
struct BST
{
    nod d[N];
    int tot,root;
    void insert(int &cur,int x) //在以cur为根的子树中,插入一个权值为x的元素 
    {
        if(!cur)
        {
            cur = ++tot;
            d[cur]={x,1,1,0,0};
            return ;
        }
        d[cur].siz++;
        if(d[cur].val==x) d[cur].cnt++;
        else if(x<d[cur].val) insert(d[cur].lc,x);
        else insert(d[cur].rc,x); 
    }
    int findRank(int cur,int x) //查找x在以cur为根的子树中的排名(小于x的元素数量+1) 
    {
        if(!cur) return 1;
        if(d[cur].val==x) return d[d[cur].lc].siz+1;
        if(x<d[cur].val) return findRank(d[cur].lc,x);
        else return d[d[cur].lc].siz + d[cur].cnt + findRank(d[cur].rc,x);
    }
    int kth(int cur,int k) //查找以cur为根的子树中,排名为k的元素大小(从小到大排序后(非严格)第k小元素) 
    {
        if(k<=d[d[cur].lc].siz) return kth(d[cur].lc,k);
        if(k<=d[d[cur].lc].siz + d[cur].cnt) return d[cur].val;
        return kth(d[cur].rc , k - d[d[cur].lc].siz - d[cur].cnt);
    }

    bool find(int x) //查找x是否存在 
    {
        return kth(root,findRank(root,x))==x;
    }
    int findcnt(int x) //查找x的出现次数 
    {
        return findRank(root,x+1)-findRank(root,x);
    }
    int findcnt(int l,int r) //查找l~r范围内元素出现次数 
    {
        return findRank(root,r+1)-findRank(root,l);
    }

    int pre(int x) //找前驱,最大的严格小于x的元素,x可以在二叉查找树中的权值中不存在 
    {
        return kth(root,findRank(root,x)-1); 
    }
    int suc(int x) //找后继,最小的严格大于x的元素,x可以在二叉查找树中的权值中不存在 
    {
        return kth(root,findRank(root,x+1));
    }

    bool erase(int cur,int x)
    {
        if(!cur) return false;
        if(d[cur].val==x)
        {
            if(d[cur].cnt>0)
            {
                d[cur].cnt--;
                d[cur].siz--;
                return true;
            }
            else return false;
        }
        if(x<d[cur].val)
        {
            if(erase(d[cur].lc,x))
            {
                d[cur].siz--;
                return true;
            }
            else return false;
        }
        else
        {
            if(erase(d[cur].rc,x))
            {
                d[cur].siz--;
                return true;
            }
            else return false;
        }
    }
};
int main()
{

    return 0;
}

旋转式Treap

#include <bits/stdc++.h>
using namespace std;
struct nod{//存储二叉搜索树节点的一个结构体
    int val, cnt, siz, ch[2], rank;
    //val表示结点权值
    //cnt表示val重复出现的次数
    //siz表示二叉搜索树中以当前节点为根的子树大小(子树cnt之和)
    //ch[0]表示二叉搜索树中当前节点的左儿子编号
    //ch[1]表示二叉搜索树中当前节点的右儿子编号
};
struct BST{
    nod d[101000];
    int tot, root;
    //tot表示当前使用节点的数量 使用的节点为d[1]~d[tot]
    //root表示当前根节点的编号
    void upd_size(int cur){//更新以cur节点的子树大小
        d[cur].siz = d[cur].cnt;
        d[cur].siz += d[d[cur].ch[0]].siz;
        d[cur].siz += d[d[cur].ch[1]].siz;
    }
    void rotate(int &cur, bool dir){//将以cur为根的子树,方向为dir(0或1)的儿子旋转到原先cur的位置 
        int tmp = d[cur].ch[dir];
        d[cur].ch[dir] = d[tmp].ch[!dir];
        d[tmp].ch[!dir] = cur;
        upd_size(cur);
        upd_size(tmp);
        cur = tmp;
    }
    void insert(int &cur, int x){//在以cur为根的子树中插入一个数x
        if (!cur){//cur==0表示空节点
            cur = ++tot; 
            d[cur] = {x, 1, 1, {0, 0}, rand()};
            return;
        }
        d[cur].siz++; 
        if (d[cur].val == x){
            d[cur].cnt++;
        }else if (x < d[cur].val){
            insert(d[cur].ch[0], x);
            if (d[d[cur].ch[0]].rank < d[cur].rank){
                rotate(cur, 0);
            }
        }else{
            insert(d[cur].ch[1], x);
            if (d[d[cur].ch[1]].rank < d[cur].rank){
                rotate(cur, 1);
            }
        }
    }
    bool erase(int cur, int x){//成功找到x并删除,返回true,否则返回false
        if (!cur) return false;
        if (d[cur].val == x){
            if (d[cur].cnt > 0){
                d[cur].cnt--;
                d[cur].siz--;
                return true;
            }
        }else if (x < d[cur].val){
            if (erase(d[cur].ch[0], x)){
                d[cur].siz--;
                return true;
            }
        }else{
            if (erase(d[cur].ch[1], x)){
                d[cur].siz--;
                return true;
            }
        }
        return false;
    }
    int findRank(int cur, int x){//在以cur为根的子树中找x的排名(比x小的数数量+1)
        if (!cur) return 1;
        if (d[cur].val == x)
            return d[d[cur].ch[0]].siz + 1;
        if (x < d[cur].val )
            return findRank(d[cur].ch[0], x);
        else
            return d[d[cur].ch[0]].siz + d[cur].cnt + findRank(d[cur].ch[1], x);
    }
    int kth(int cur, int k){//在以cur为根的子树中求第k小的数(排名为k的数)
        if (k <= d[d[cur].ch[0]].siz)
            return kth(d[cur].ch[0], k);
        if (k <= d[d[cur].ch[0]].siz + d[cur].cnt)
            return d[cur].val;
        return kth(d[cur].ch[1], k - d[d[cur].ch[0]].siz - d[cur].cnt);
    }
    int pre(int x){//x的前驱
        return kth(root, findRank(root, x) - 1);
    }
    int suc(int x){//x的后继
        return kth(root, findRank(root, x + 1));
    }
}T;
int q, op, x; 
int main(){
    srand(time(0));//初始化随机数种子
    cin >> q;
    while (q--){
        cin >> op >> x;
        if (op == 1){
            T.insert(T.root, x);
        }else if (op == 2){
            T.erase(T.root, x);
        }else if (op == 3){
            cout << T.findRank(T.root, x) << endl;
        }else if (op == 4){
            cout << T.kth(T.root, x) << endl;
        }else if (op == 5){
            cout << T.pre(x) << endl;
        }else if (op == 6){
            cout << T.suc(x) << endl;
        }
    }
    return 0;
}

非旋转式Treap

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5;
struct BST{
    int tot,root;
    int val[N]; //val[i] 表示节点i的点权 
    int pri[N]; //pri[i] 表示节点i的优先级 
    int ch[N][2];
    int siz[N];

    int build(int k) //创建一个新的权值为k的节点,返回该点的编号
    {
        int u=++tot;
        val[u]=k;
        pri[u]=rand();
        ch[u][0]=ch[u][1]=0;
        siz[u]=1;
        return u;
    }
    void maintain(int u) //重新计算以u为根的子树大小
    {
        if(u) siz[u]=siz[ch[u][0]] + siz[ch[u][1]] + 1;
    }
    void spilt(int u,int k,int &l,int &r)
    //把以u为根的Treap分裂为两棵Treap l和r,使得l树的所有权值<k,r树的所有权值>=k 
    {
        if(u==0)
        {
            l=r=0;
            return ;
        }
        if(k<=val[u])
        {
            spilt(ch[u][0],k,l,ch[u][0]);
            r=u;
        }
        else
        {
            l=u;
            spilt(ch[u][1],k,ch[u][1],r);
        }
        maintain(u);
    }
    int merge(int u,int v) //树u的val <= 树v的val,将u和v合并,返回新的根节点 
    {
        if(u==0 || v==0) return u|v;
        if(pri[u]<pri[v])
        {
            ch[u][1]=merge(ch[u][1],v);
            maintain(u);
            return u;
        }
        else
        {
            ch[v][0]=merge(u,ch[v][0]);
            maintain(v);
            return v;
        }
    }
    void insert(int &root,int k) //插入一个新的元素k 
    {
        int l,r;
        spilt(root,k,l,r);
        root=merge(merge(l,build(k)),r); //merge(l,merge(build(k),r));
    }
    void erase(int &root,int k) //删除一个元素k 
    {
        int l,u,r;
        spilt(root,k,l,r);
        spilt(r,k+1,u,r);
        root=merge(merge(l,merge(ch[u][0],ch[u][1])),r);
    }
    int rank(int root,int k) //找k在root中的排名(<k的元素数量) 
    {
        int u=root;
        int ret=0;
        while(u)
        {
            if(val[u]<k)
            {
                ret+=siz[ch[u][0]]+1;
                u=ch[u][1];
            }
            else u=ch[u][0];
        }
        return ret;
    }
    int select(int root,int k) //找root为根的子树排序后第k+1小的元素 
    {
        int u=root;
        while(u)
        {
            if(siz[ch[u][0]]==k) break;
            if(k<siz[ch[u][0]]) u=ch[u][0];
            else k-=siz[ch[u][0]]+1 , u=ch[u][1]; 
        }
        return val[u];
    }
    int pre(int x)
    {
        return select(root,rank(root,x)-1);
    }
    int suc(int x)
    {
        return select(root,rank(root,x+1));
    }
};
int main()
{
    srand(time(0));

    return 0;
}

拓展

以下代码也可以执行平衡树的操作,不过与模板中的函数调用方式不同。

#include <bits/stdc++.h>
#include<bits/extc++.h>
using namespace std;
__gnu_pbds::tree<pair<int, int> , __gnu_pbds::null_type, less<pair<int, int>>,
                 __gnu_pbds::rb_tree_tag,
                 __gnu_pbds::tree_order_statistics_node_update>
    T;
int main()
{

    return 0;
}