常用平衡树

· · 个人记录

Problem : 普通平衡树

Treap

#include<bits/stdc++.h>
using namespace std;
mt19937 rand_num(time(0));

struct Treap{
    int val,cnt,ls,rs,siz,dat;
}T[100050];

const int maxn=2147483647,minn=-2147483647;
int n,x,op,root=1,k;

void add(){
    T[++k].val=x;
    T[k].dat=rand_num();
    T[k].cnt=1;
    T[k].siz=1;
}

void update(int no){
    T[no].siz=T[no].cnt+T[T[no].ls].siz+T[T[no].rs].siz;
}

void zig(int &no){
    int tmp=T[no].ls;
    T[no].ls=T[tmp].rs;
    T[tmp].rs=no;
    no=tmp;
    update(T[no].rs);
    update(no);
}

void zag(int &no){
    int tmp=T[no].rs;
    T[no].rs=T[tmp].ls;
    T[tmp].ls=no;
    no=tmp;
    update(T[no].ls);
    update(no);
}

void insert(int &no){
    if(no==0){
        add();no=k;return;
    }if(T[no].val==x){
        T[no].cnt++;
        T[no].siz++;
        return;
    }if(T[no].val>x){
        insert(T[no].ls);
        if(T[no].dat<T[T[no].ls].dat)zig(no);
    }else{
        insert(T[no].rs);
        if(T[no].dat<T[T[no].rs].dat)zag(no);
    }update(no);
}

void del(int &no){
    if(no==0)return;
    if(T[no].val==x){
        if(T[no].cnt>1){
            T[no].cnt--;
            update(no);
            return;
        }if(T[no].ls||T[no].rs){
            if(T[no].ls!=0&&(T[no].rs==0||T[T[no].ls].dat>=T[T[no].rs].dat)){
                zig(no);del(T[no].rs);
            }else{
                zag(no);del(T[no].ls);
            }update(no);
        }else no=0;
        return;
    }if(T[no].val>x)del(T[no].ls);
    else del(T[no].rs);
    update(no);
}

int find_rank(int no){
    if(no==0)return 0;
    if(T[no].val==x)return T[T[no].ls].siz;
    if(T[no].val<x)return T[T[no].ls].siz+T[no].cnt+find_rank(T[no].rs);
    return find_rank(T[no].ls);
}

int find_num(int no){
    if(no==0)return maxn;
    if(T[T[no].ls].siz>=x)return find_num(T[no].ls);
    if(T[T[no].ls].siz+T[no].cnt>=x)return T[no].val;
    x-=T[T[no].ls].siz+T[no].cnt;
    return find_num(T[no].rs);
}

int last_num(int no){
    if(no==0)return minn;
    if(T[no].val>=x)return last_num(T[no].ls);
    return max(T[no].val,last_num(T[no].rs));
}

int next_num(int no){
    if(no==0)return maxn;
    if(T[no].val<=x)return next_num(T[no].rs);
    return min(T[no].val,next_num(T[no].ls));
}

int main()
{
    srand(time(0));
    x=maxn;
    insert(root);
    x=minn;
    insert(root);
    update(root);
    cin>>n;
    while(n--){
        cin>>op>>x;
        switch(op){
            case 1:{
                insert(root);
                break;
            }case 2:{
                del(root);
                break;
            }case 3:{
                cout<<find_rank(root)<<endl;
                break;
            }case 4:{
                x++;
                cout<<find_num(root)<<endl;
                break;
            }case 5:{
                cout<<last_num(root)<<endl;
                break;
            }case 6:{
                cout<<next_num(root)<<endl;
                break;
            }default:break;
        }
    }
    return 0;
}

Splay

#include<iostream>
using namespace std;

struct node{
    int val,maxn,minn;
    int cnt,siz;
    int left,right,fa;
}T[1000050];

int n,op,x,root,nth;

void update(int no){
    T[no].siz=T[no].cnt+T[T[no].left].siz+T[T[no].right].siz;
    T[no].maxn=max(T[no].val,max(T[T[no].left].maxn,T[T[no].right].maxn));
    T[no].minn=min(T[no].val,min(T[T[no].left].minn,T[T[no].right].minn));
}

bool son(int no){
    return T[T[no].fa].right==no;
}

void rotate(int no){
    int tmp=T[no].fa,tmp2=T[tmp].fa;
    bool flag=son(tmp);
    if(son(no)==0)
        T[tmp].left=T[no].right,T[T[no].right].fa=tmp,
        T[no].right=tmp,T[tmp].fa=no;
    else
        T[tmp].right=T[no].left,T[T[no].left].fa=tmp,
        T[no].left=tmp,T[tmp].fa=no;
    T[no].fa=tmp2;
    update(tmp);
    update(no);
    if(tmp2==0)return;
    if(flag==0)T[tmp2].left=no;
    else T[tmp2].right=no;
}

void splay(int no,int target=0){
    while(T[no].fa!=target){
        if(T[T[no].fa].fa!=target&&son(no)==son(T[no].fa))rotate(T[no].fa);
        rotate(no);
    }
    if(target==0)root=no;
}

void insert(int &no,int val,int lst){
    if(no==0){
        no=++nth;
        T[no].val=val;
        T[no].fa=lst;
        T[no].siz=T[no].cnt=1;
        splay(no);
        return;
    }
    if(T[no].val==val){
        ++T[no].siz;
        ++T[no].cnt;
        splay(no);
        return;
    }
    if(T[no].val>val)insert(T[no].left,val,no);
    else insert(T[no].right,val,no);
}

int kth_num(int no,int rk){
    if(no==0)return 0;
    if(T[T[no].left].siz>=rk)return kth_num(T[no].left,rk);
    if(T[T[no].left].siz<rk&&T[T[no].left].siz+T[no].cnt>=rk){
        splay(no);
        return T[no].val;
    }
    return kth_num(T[no].right,rk-T[T[no].left].siz-T[no].cnt);
}

int pre_num(int no,int val){
    if(no==0)return -2147483647;
    if(T[no].val>=val)return pre_num(T[no].left,val);
    if(T[T[no].right].minn<val)return pre_num(T[no].right,val);
    splay(no);
    return T[no].val;
}

int nxt_num(int no,int val){
    if(no==0)return 2147483647;
    if(T[no].val<=val)return nxt_num(T[no].right,val);
    if(T[T[no].left].maxn>val)return nxt_num(T[no].left,val);
    splay(no);
    return T[no].val;
}

int Rank(int val){
    pre_num(root,val);
    return T[root].cnt+T[T[root].left].siz;
}

void bl(int no){
    if(no==0)return;
    bl(T[no].left);
    for(int i=1;i<=T[no].cnt;i++)cout<<T[no].val<<" ";
    bl(T[no].right);
}

void del(int val){
    pre_num(root,val+1);
    if(T[root].cnt>1){
        --T[root].cnt;
        --T[root].siz;
        return;
    }
    int L=T[root].left,R=T[root].right;
    T[R].fa=0;
    root=R;
    kth_num(root,1);
    T[root].left=L;
    T[L].fa=root;
}

int main()
{
    cin>>n;
    T[0].maxn=-2147483647;
    T[0].minn=2147483647;
    insert(root,-2147483647,0);
    insert(root,2147483647,0);
    while(n--){
        cin>>op>>x;
        switch(op){
            case 1:{
                insert(root,x,0);
                break;
            }
            case 2:{
                del(x);
                break;
            }
            case 3:{
                cout<<Rank(x)<<endl;
                break;
            }
            case 4:{
                cout<<kth_num(root,x+1)<<endl;
                break;
            }
            case 5:{
                cout<<pre_num(root,x)<<endl;
                break;
            }
            case 6:{
                cout<<nxt_num(root,x)<<endl;
                break;
            }
        }
    }
    return 0;
}

FHQ Treap

#include<iostream>
#include<random>
using namespace std;

mt19937 rand_num(time(0));
class FHQ_Treap{
    private:
        struct node{
            int val,dat,left,right,siz;
        }Tr[100050];
        int root,x,nth;
    public:
        void update(int no){
            if(no==0)return;
            Tr[no].siz=Tr[Tr[no].left].siz+Tr[Tr[no].right].siz+1;
        }
        void split(int no,int &L,int &R){
            if(no==0){
                L=0;R=0;
                return;
            }
            int tmp1,tmp2;
            if(Tr[no].val<=x){
                L=no;
                split(Tr[no].right,tmp1,tmp2);
                Tr[L].right=tmp1;
                R=tmp2;
                update(L);
                update(R);
            }else{
                R=no;
                split(Tr[no].left,tmp1,tmp2);
                Tr[R].left=tmp2;
                L=tmp1;
                update(L);
                update(R);
            }
        }
        void merge(int L,int R,int &Root){
            if(L==0||R==0){
                Root=max(L,R);
                return;
            }
            int tmp;
            if(Tr[L].dat>=Tr[R].dat){
                Root=L;
                merge(Tr[L].right,R,tmp);
                Tr[L].right=tmp;
                update(Root);
            }else{
                Root=R;
                merge(L,Tr[R].left,tmp);
                Tr[R].left=tmp;
                update(Root);
            }
        }
        int new_element(int x){
            ++nth;
            Tr[nth].val=x;
            Tr[nth].dat=rand_num();
            Tr[nth].siz=1;
            return nth;
        }
        void insert(int val){
            x=val;
            int n=new_element(val),L,R;
            split(root,L,R);
    //      cout<<Tr[0].siz<<" "<<Tr[0].left<<" "<<Tr[0].right<<endl;
            merge(L,n,root);
    //      cout<<Tr[0].siz<<" "<<Tr[0].left<<" "<<Tr[0].right<<endl;
            merge(root,R,root);
    //      cout<<Tr[0].siz<<" "<<Tr[0].left<<" "<<Tr[0].right<<endl;
        }
        void del(int val){
            x=val-1;
            int rt1,rt2,rt3;
            split(root,rt1,rt2);
            x=val;
            split(rt2,rt2,rt3);
            merge(Tr[rt2].left,Tr[rt2].right,rt2);
            merge(rt1,rt2,root);
            merge(root,rt3,root);
        }
        void bl(int no){
            if(no==0)return;
            bl(Tr[no].left);
            cout<<Tr[no].val<<" ";
            bl(Tr[no].right);
        }
        void check(){
            bl(root);
        }
        int kth(int no){
            if(no==0)return 0;
            if(x<=Tr[Tr[no].left].siz)return kth(Tr[no].left);
            if(x-Tr[Tr[no].left].siz==1)return Tr[no].val;
            x-=Tr[Tr[no].left].siz+1;
            return kth(Tr[no].right);
        }
        int kth_num(int val){
            x=val;
            return kth(root);
        }
        int rank(int no){
            if(no==0)return 0;
            if(Tr[no].val<x)return Tr[Tr[no].left].siz+1+rank(Tr[no].right);
            return rank(Tr[no].left);
        }
        int find_rank(int val){
            x=val;
            return rank(root)+1;
        }
        int pre(int no){
            if(no==0)return -2147483647;
            if(Tr[no].val<x)return max(Tr[no].val,pre(Tr[no].right));
            return pre(Tr[no].left);
        }
        int pre_num(int val){
            x=val;
            return pre(root);
        }
        int nxt(int no){
            if(no==0)return 2147483647;
            if(Tr[no].val>x)return min(Tr[no].val,nxt(Tr[no].left));
            return nxt(Tr[no].right);
        }
        int nxt_num(int val){
            x=val;
            return nxt(root);
        }
        void details(){
            cout<<Tr[0].siz<<" "<<Tr[0].left<<" "<<Tr[0].right<<endl;
            cout<<"root : "<<root<<endl;
            for(int i=1;i<=nth;i++)
                cout<<i<<" : "<<Tr[i].left<<" "<<Tr[i].right<<" "<<Tr[i].siz<<endl;
        }
}T;

int n,op,x;

int main()
{
    cin>>n;
    while(n--){
        cin>>op>>x;
        switch(op){
            case 1:{
                T.insert(x);
                break;
            }
            case 2:{
                T.del(x);
                break;
            }
            case 3:{
                cout<<T.find_rank(x)<<endl;
                break;
            }
            case 4:{
                cout<<T.kth_num(x)<<endl;
                break;
            }
            case 5:{
                cout<<T.pre_num(x)<<endl;
                break;
            }
            case 6:{
                cout<<T.nxt_num(x)<<endl;
                break;
            }
            case 7:{
                T.check();puts("");
                break;
            }
            default:{
                T.details();puts("");
                break;
            }
        }
    }
    return 0;
}

0-1 Trie

#include<cstdio>
#include<cstring>
using namespace std;

const int maxn=2147483647;

class Trie{
    private:
        struct node{
            int left,right,siz;
        }Tr[2500005];
        int root,nth,len;
        bool bit[26];
    public:
        void transform(int x){
            memset(bit,0,sizeof(bit));
            int cnt=0,op;
            if(x<0)x=-x,op=1;
            else bit[len]=1,op=0;
            while(x){
                bit[++cnt]=(x&1)^op;
                x>>=1;
            }
            while(op&&cnt<len-1)bit[++cnt]=1;
        }
        void prework(){
            len=25;root=++nth;
        }
        void insert(int x){
            transform(x);
            int p=root;
            Tr[p].siz++;
            for(int i=len;i>=1;i--){
                if(bit[i]==0){
                    if(Tr[p].left==0)Tr[p].left=++nth;
                    p=Tr[p].left;
                }else{
                    if(Tr[p].right==0)Tr[p].right=++nth;
                    p=Tr[p].right;
                }Tr[p].siz++;
            }
        }
        void del(int x){
            transform(x);
            int p=root;
            Tr[p].siz--;
            for(int i=len;i>=1;i--){
                if(bit[i]==0){
                    if(Tr[p].left==0)Tr[p].left=++nth;
                    p=Tr[p].left;
                }else{
                    if(Tr[p].right==0)Tr[p].right=++nth;
                    p=Tr[p].right;
                }Tr[p].siz--;
            }
        }
        int kth_num(int x){
            if(x<=0)return -maxn;
            if(x>Tr[root].siz)return maxn;
            int p=root,ret=0,op;
            for(int i=len;i>=1;i--){
                ret<<=1;
                if(Tr[Tr[p].left].siz>=x&&Tr[p].left){
                    p=Tr[p].left;
                    if(i==len)op=-1;
                    else if(op==-1)ret|=1;
                }else{
                    x-=Tr[Tr[p].left].siz;
                    p=Tr[p].right;
                    if(i==len)op=1;
                    else if(op==1)ret|=1;
                }
            }return ret*op;
        }
        int rank(int x){
            transform(x);
            int p=root,ret=0;
            for(int i=len;i>=1;i--){
                if(bit[i]==0)p=Tr[p].left;
                else ret+=Tr[Tr[p].left].siz,p=Tr[p].right;
            }return ret+1;
        }
        int pre_num(int x){
            return kth_num(rank(x)-1);
        }
        int nxt_num(int x){
            return kth_num(rank(x+1));
        }
}T;

inline int read(){
    int n=0,f=1;
    char c=getchar();
    while(c<'0' || c>'9'){
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0' && c<='9'){
        n=(n<<3)+(n<<1)+(c^48);
        c=getchar();
    }
    return n*f;
}
void write(int x){
    if(x<0){
        putchar('-');
        x=-x;
    }
    if(x>9)write(x/10);
    putchar(x%10^48);
    return;
}

int n,op,x;

int main()
{
    T.prework();
    n=read();
    while(n--){
        op=read();x=read();
        switch(op){
            case 1:{
                T.insert(x);
                break;
            }case 2:{
                T.del(x);
                break;
            }case 3:{
                write(T.rank(x));puts("");
                break;
            }case 4:{
                write(T.kth_num(x));puts("");
                break;
            }case 5:{
                write(T.pre_num(x));puts("");
                break;
            }case 6:{
                write(T.nxt_num(x));puts("");
                break;
            }default:break;
        }
    }
    return 0;
}

权值线段树

#include<iostream>
using namespace std;

struct SegmentTree{
    int left,right,siz;
}T[5000050];

const int maxn=10000050;
int n,op,x,root,nth;

void insert(int l,int r,int &no){
    if(l>x||r<x)return;
    if(no==0)no=++nth;
    ++T[no].siz;
    if(l==r)return;
    int mid=(l+r)>>1;
    insert(l,mid,T[no].left);
    insert(mid+1,r,T[no].right);
}

void del(int l,int r,int no){
    if(l>x||r<x||no==0)return;
    if(l==r){
        --T[no].siz;
        return;
    }
    int mid=(l+r)>>1;
    del(l,mid,T[no].left);
    del(mid+1,r,T[no].right);
    T[no].siz=T[T[no].left].siz+T[T[no].right].siz;
}

int Rank(int l,int r,int no){
    if(l>=x||no==0)return 0;
    if(r<x)return T[no].siz;
    int mid=(l+r)>>1;
    return Rank(l,mid,T[no].left)+Rank(mid+1,r,T[no].right);
}

int kth(int l,int r,int no,int rk){
    if(l==r)return l;
    int mid=(l+r)>>1;
    if(T[T[no].left].siz>=rk)return kth(l,mid,T[no].left,rk);
    return kth(mid+1,r,T[no].right,rk-T[T[no].left].siz);
}

int lst(){
    return kth(-maxn,maxn,root,Rank(-maxn,maxn,root));
}

int nxt(){
    ++x;
    return kth(-maxn,maxn,root,Rank(-maxn,maxn,root)+1);
}

int main()
{
    cin>>n;
    while(n--){
        cin>>op>>x;
        switch(op){
            case 1:{
                insert(-maxn,maxn,root);
                break;
            }
            case 2:{
                del(-maxn,maxn,root);
                break;
            }
            case 3:{
                cout<<Rank(-maxn,maxn,root)+1<<endl;
                break;
            }
            case 4:{
                cout<<kth(-maxn,maxn,root,x)<<endl;
                break;
            }
            case 5:{
                cout<<lst()<<endl;
                break;
            }
            case 6:{
                cout<<nxt()<<endl;
                break;
            }
            default:break;
        }
    }
    return 0;
}