旋转 Treap

· · 个人记录

宏定义

常量与变量

函数

#define pl a[p].l
#define pr a[p].r
#define pv a[p].val
#define pc a[p].cnt
#define ps a[p].size
#define pd a[p].dat
struct Treap{
    const TT INF=0x7fffffff;
    int tot,root,nc[N];
    struct Tree{
        int l,r;
        TT val;
        int dat,cnt,size;
    }a[N];
    int compare(TT x,TT y){
        if(x<y)
            return -1;
        if(x==y)
            return 0;
        if(x>y)
            return 1;
    }
    void update(int p){
        ps=a[pl].size+a[pr].size+pc;
    }
    int get_new(TT val){
        int p=nc[++tot];
        pv=val;
        pc=ps=1;
        pl=pr=0;
        pd=rand();
        return p;
    }
    void del(int &p){
        nc[tot--]=p;
        p=0;
    }
    void build(){
        tot=0;
        for(int i=1;i<N;i++)
            nc[i]=i;
        get_new(-INF);get_new(INF);
        a[root=1].r=2;
        update(root);
    }
    int get(int p,TT val){
        if(!p)
            return 0;
        if(!compare(val,pv))
            return p;
        return get((compare(val,pv)<0?pl:pr),val);
    }
    void zig(int &p){
        int q=pl;
        pl=a[q].r;a[q].r=p;p=q;
        update(pr);update(p);
    }
    void zag(int &p){
        int q=a[p].r;
        pr=a[q].l;a[q].l=p;p=q;
        update(pl);update(p);
    }
    void insert(int &p,TT val){
        if(!p){
            p=get_new(val);
            return;
        }
        if(!compare(val,pv)){
            pc++;
            update(p);
            return;
        }
        if((compare(val,pv)<0)){
            insert(pl,val);
            if(pd<a[pl].dat)
                zig(p);
        }
        else{
            insert(pr,val);
            if(pd<a[pr].dat)
                zag(p);
        }
        update(p);
    }
    void remove(int &p,TT val){
        if(p==0)
            return;
        if(!compare(val,pv)){
            if(pc>1){
                pc--;
                update(p);
                return;
            }
            if(pl||pr){
                if(!pr||a[pl].dat>a[pr].dat){
                    zig(p);
                    remove(pr,val);
                }
                else{
                    zag(p);
                    remove(pl,val);
                }
                update(p);
                return;
            }
            del(p);
            return;
        }
        if((compare(val,pv)<0)) 
            remove(pl,val);
        else
            remove(pr,val);
        update(p);
    }
    int get(int p,int val){
        if(!p)
            return 0;
        if(val==pv)
            return p;
        return get((val<pv?pl:pr),val);
    }
    int get_pre(TT val){
        int ans=1,p=root;
        while(p){
            if(!compare(val,pv)){
                if(pl){
                    p=pl;
                    while(pr)
                        p=pr;
                    ans=p; 
                }
                break;
            }
            if((compare(pv,val)<0&&compare(pv,a[ans].val)>0))
                ans=p;
            p=(compare(val,pv)<0?pl:pr);
        }
        return ans;
    }
    int get_next(TT val){
        int ans=2,p=root;
        while(p){
            if(!compare(val,pv)){
                if(pr){
                    p=pr;
                    while(pl)
                        p=pl;
                    ans=p;
                }
                break;
            }
            if((compare(pv,val)>0&&compare(pv,a[ans].val)<0))
                ans=p;
            p=(compare(val,pv)<0?pl:pr);
        }
        return ans;
    }
    int V_to_R(int p,TT val){//val to rank
        if(!p)
            return 0;
        if(!compare(val,pv))
            return a[pl].size;
        if(compare(val,pv)<0)
            return V_to_R(pl,val);
        return V_to_R(pr,val)+a[pl].size+pc;
    }
    TT R_to_V(int p,int rank){//rank to val
        if(!p)
            return INF;
        if(a[pl].size>=rank)
            return R_to_V(pl,rank);
        if(a[pl].size+pc>=rank)
            return pv;
        return R_to_V(pr,rank-a[pl].size-pc);
    }
    void write(int p){
        if(!p)
            return;
        printf("%d  val:%d  lson:%d %d  rson:%d %d\n",p,pv,pl,a[pl].val,pr,a[pr].val);
        write(pl);write(pr);
    }
}tree;