Treap

· · 算法·理论

BST

在普通二叉树的基础上多一个条件：对于任意结点 $x$，其左子树 $L$ 中的每个结点 $p\in L$ 都满足 $v_p\leq v_x$；

其右子树 $R$ 中的每个结点 $q\in R$ 都满足 $v_x<v_q$。

Treap

但是如果 BST 退化成一条链，单次操作的时间复杂度就会退化为 $O(n)$，而且这种情况很容易被构造数据卡到，因此考虑 Treap。

Treap 就是在普通 BST 的基础上给每个结点附加一个随机的优先级（点权），并要求优先级满足堆性质。由于优先级是随机的，树的期望深度为 $O(\log n)$，从而各项操作的期望时间复杂度稳定在 $O(\log n)$。

旋转 Treap

#include<bits/stdc++.h>

using namespace std;

// One slot of the Treap node pool. Slot 0 acts as the null node and must stay
// all-zero (in particular s(0) == 0).
struct node {
    int ls, rs;   // children as pool indices, 0 = none
    int val, dat; // key and random priority (kept as a max-heap by the rotations)
    int cnt, siz; // multiplicity of val, total multiplicity of the subtree
    #define ls(p) tr[p].ls
    #define rs(p) tr[p].rs
    #define v(p) tr[p].val
    #define d(p) tr[p].dat
    #define c(p) tr[p].cnt
    #define s(p) tr[p].siz
};

const int INF=1<<30;  // sentinel keys are -INF and INF

// Rotating Treap holding a multiset of ints. build() must be called first: it
// places the sentinels -INF (node 1, the initial root) and INF (node 2). The
// rank queries compensate for the -INF sentinel, and getpre/getnxt use nodes
// 1 and 2 as their fallback answers.
struct treap {
    int top = 0, rt = 0;   // last allocated pool slot, root index
    std::vector<node> tr;  // node pool; grows on demand in insert(int)

    // Allocates slot top+1 as a leaf holding val with a random priority.
    // Requires tr.size() > top+1 (build() and insert(int) guarantee this).
    int nnew(int val) {
        v(++top)=val;
        d(top)=rand();
        c(top)=s(top)=1;
        return top;
    }
    // Recomputes p's subtree size from its children and its own count.
    void pushup(int p) {
        s(p)=s(ls(p))+s(rs(p))+c(p);
    }
    // (Re)initialises an empty tree. siz is only an initial capacity hint.
    void build(int siz) {
        // Slot 0 is the null node; slots 1 and 2 must exist for the sentinels.
        tr.assign(std::max(siz, 2) + 1, node{});
        top = 0;  // restart allocation so the sentinels land in nodes 1 and 2
        rt=nnew(-INF);
        rs(rt)=nnew(INF);
        pushup(rt);
    }
    // Returns 1 + (number of keys < val) in subtree p, counting multiplicities.
    // This is val's rank inside the subtree whether or not val is stored.
    int getrk(int p,int val) {
        if(p==0) return 1;  // val absent: its rank is one past all smaller keys
        if(val==v(p)) return s(ls(p))+1;
        if(val<v(p)) return getrk(ls(p),val);
        return getrk(rs(p),val)+s(ls(p))+c(p);
    }
    // Rank of val among the real keys (1-based); the -1 drops the -INF sentinel.
    int getrk(int val) {
        return getrk(rt,val)-1;
    }
    // rk-th smallest key (1-based, with multiplicities) in subtree p;
    // INF when rk is out of range.
    int getval(int p,int rk) {
        if(p==0) return INF;
        if(rk<=s(ls(p))) return getval(ls(p),rk);
        if(rk<=s(ls(p))+c(p)) return v(p);
        return getval(rs(p),rk-s(ls(p))-c(p));
    }
    // rk-th smallest real key; the +1 skips the -INF sentinel.
    int getval(int rk) {
        return getval(rt,rk+1);
    }
    // Right rotation: lifts p's left child into p's place.
    void uplt(int &p) {
        int q=ls(p);
        ls(p)=rs(q), rs(q)=p, p=q;
        pushup(rs(p)), pushup(p);
    }
    // Left rotation: lifts p's right child into p's place.
    void uprt(int &p) {
        int q=rs(p);
        rs(p)=ls(q), ls(q)=p, p=q;
        pushup(ls(p)), pushup(p);
    }
    // Inserts one occurrence of val below p. A new node is rotated up while its
    // priority exceeds its parent's. This overload allocates at most one node
    // and holds references into tr, so the pool must already have a free slot.
    void insert(int &p,int val) {
        if(p==0) {
            p=nnew(val);
            return;
        }
        if(val==v(p)) {
            c(p)++, pushup(p);
            return;
        }
        if(val<v(p)) {
            insert(ls(p),val);
            if(d(p)<d(ls(p))) uplt(p);
        }
        else {
            insert(rs(p),val);
            if(d(p)<d(rs(p))) uprt(p);
        }
        pushup(p);
    }
    // Inserts one occurrence of val, growing the pool first if it is full.
    void insert(int val) {
        // Grow before recursing: a reallocation inside insert(p, val) would
        // invalidate the int& references it holds into tr.
        if(top+1>=(int)tr.size()) tr.resize(std::max<std::size_t>(2*tr.size(),4));
        insert(rt,val);
    }
    // Largest key strictly less than val. Falls back to node 1 (-INF).
    int getpre(int val) {
        int ans=1, p=rt;
        for( ; p; p=val<v(p)?ls(p):rs(p)) {
            if(val==v(p)) {
                // the answer is the maximum of the left subtree, if any
                if(ls(p)) {
                    for(p=ls(p); rs(p); p=rs(p));
                    ans=p;
                }
                break;
            }
            if(v(p)<val&&v(p)>v(ans)) ans=p;
        }
        return v(ans);
    }
    // Smallest key strictly greater than val. Falls back to node 2 (INF).
    int getnxt(int val) {
        int ans=2, p=rt;
        for( ; p; p=val<v(p)?ls(p):rs(p)) {
            if(v(p)==val) {
                // the answer is the minimum of the right subtree, if any
                if(rs(p)) {
                    for(p=rs(p); ls(p); p=ls(p));
                    ans=p;
                }
                break;
            }
            if(v(p)>val&&v(p)<v(ans)) ans=p;
        }
        return v(ans);
    }
    // Removes one occurrence of val (no-op if absent). A node whose count would
    // drop to zero is rotated down, lifting the child with the larger priority,
    // until it is a leaf. It is then unlinked; its pool slot is not reused.
    void remove(int &p,int val) {
        if(p==0) return;
        if(val==v(p)) {
            if(c(p)>1) {
                c(p)--, pushup(p);
                return;
            }
            if(ls(p)||rs(p)) {
                if(rs(p)==0||d(ls(p))>d(rs(p)))
                    uplt(p), remove(rs(p),val);
                else uprt(p), remove(ls(p),val);
                pushup(p);
            }
            else p=0;
            return;
        }
        val<v(p)?remove(ls(p),val):remove(rs(p),val);
        pushup(p);
    }
    void remove(int val) {
        remove(rt,val);
    }
} qwq;
#include<bits/stdc++.h>
#define int long long
#define MN 101000 

using namespace std;

// One slot of the global node pool tr[]; slot 0 serves as the null node.
struct node {
    int ls;   // left child index (0 = none)
    int rs;   // right child index (0 = none)
    int val;  // stored key
    int dat;  // random heap priority
    int cnt;  // multiplicity of val
    int siz;  // total multiplicity of this subtree
} tr[MN];

// Field accessors used throughout the Treap functions.
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define v(p) tr[p].val
#define d(p) tr[p].dat
#define c(p) tr[p].cnt
#define s(p) tr[p].siz

// Sentinel magnitude; the 60-bit shift relies on `int` being #defined as long long.
const int INF = static_cast<int>(1) << 60;
// Last allocated pool slot and current root index (0 = none yet).
int top = 0;
int rt = 0;

int nnew(int val) {
    v(++top)=val;
    d(top)=rand();
    c(top)=s(top)=1;
    return top;
}

void pushup(int p) {
    s(p)=s(ls(p))+s(rs(p))+c(p);
}

// Creates the two sentinels: the root holds -INF (node 1) and its right child
// holds INF (node 2), so keys inside (-INF, INF) always have both neighbours.
void build() {
    const int lo = nnew(-INF);
    const int hi = nnew(INF);
    rt = lo;
    rs(rt) = hi;
    pushup(rt);
}

// Returns 1 + (number of keys < val) in subtree p, counting multiplicities.
// This holds whether or not val is stored. Called on rt, the count includes
// the -INF sentinel, which is why main() subtracts 1.
int getrk(int p,int val) {
    if(p==0) return 1;  // val is absent: its rank is one past all smaller keys
    if(val==v(p)) return s(ls(p))+1;
    if(val<v(p)) return getrk(ls(p),val);
    return getrk(rs(p),val)+s(ls(p))+c(p);
}

// Returns the rk-th smallest key (1-based, counting multiplicities) in subtree
// p, or INF when rk is outside [1, size of the subtree].
int getval(int p,int rk) {
    while (p != 0) {
        const int leftSize = s(ls(p));
        if (rk <= leftSize) {
            p = ls(p);
        } else if (rk <= leftSize + c(p)) {
            return v(p);
        } else {
            rk -= leftSize + c(p);
            p = rs(p);
        }
    }
    return INF;
}

// Right rotation: lifts p's left child into p's position. The old p becomes
// its right child; sizes are fixed bottom-up.
void uplt(int &p) {
    const int child = ls(p);
    const int oldRoot = p;
    ls(oldRoot) = rs(child);
    rs(child) = oldRoot;
    p = child;
    pushup(oldRoot);
    pushup(child);
}

// Left rotation: lifts p's right child into p's position. The old p becomes
// its left child; sizes are fixed bottom-up.
void uprt(int &p) {
    const int child = rs(p);
    const int oldRoot = p;
    rs(oldRoot) = ls(child);
    ls(child) = oldRoot;
    p = child;
    pushup(oldRoot);
    pushup(child);
}

// Inserts one occurrence of val below p. An existing key just gets its count
// bumped. A new leaf is rotated upward while its priority exceeds its parent's
// (max-heap on dat).
void insert(int &p,int val) {
    if (!p) {
        p = nnew(val);
        return;
    }
    if (v(p) == val) {
        ++c(p);
        pushup(p);
        return;
    }
    const bool goLeft = val < v(p);
    if (goLeft) {
        insert(ls(p), val);
        if (d(ls(p)) > d(p)) uplt(p);
    } else {
        insert(rs(p), val);
        if (d(rs(p)) > d(p)) uprt(p);
    }
    pushup(p);
}

// Largest key strictly less than val. Starts from node 1 (the -INF sentinel
// created first by build()), which is returned when no smaller key exists.
int getpre(int val) {
    int best = 1;
    int cur = rt;
    while (cur) {
        if (v(cur) == val) {
            // val found: the answer is the maximum of its left subtree, if any
            if (ls(cur)) {
                cur = ls(cur);
                while (rs(cur)) cur = rs(cur);
                best = cur;
            }
            break;
        }
        if (v(cur) < val && v(cur) > v(best)) best = cur;
        cur = val < v(cur) ? ls(cur) : rs(cur);
    }
    return v(best);
}

// Smallest key strictly greater than val. Starts from node 2 (the INF sentinel
// created second by build()), which is returned when no larger key exists.
int getnxt(int val) {
    int best = 2;
    int cur = rt;
    while (cur) {
        if (v(cur) == val) {
            // val found: the answer is the minimum of its right subtree, if any
            if (rs(cur)) {
                cur = rs(cur);
                while (ls(cur)) cur = ls(cur);
                best = cur;
            }
            break;
        }
        if (v(cur) > val && v(cur) < v(best)) best = cur;
        cur = val < v(cur) ? ls(cur) : rs(cur);
    }
    return v(best);
}

// Removes one occurrence of val below p; does nothing if val is absent. A
// node whose last copy is removed is rotated down, lifting the child with the
// larger priority, until it becomes a leaf. The leaf is then unlinked.
void remove(int &p,int val) {
    if (!p) return;
    if (val != v(p)) {
        if (val < v(p)) remove(ls(p), val);
        else remove(rs(p), val);
        pushup(p);
        return;
    }
    if (c(p) > 1) {
        --c(p);
        pushup(p);
        return;
    }
    if (!ls(p) && !rs(p)) {
        p = 0;
        return;
    }
    const bool liftLeft = rs(p) == 0 || d(ls(p)) > d(rs(p));
    if (liftLeft) {
        uplt(p);
        remove(rs(p), val);
    } else {
        uprt(p);
        remove(ls(p), val);
    }
    pushup(p);
}

// Driver: reads n operations "op x" and answers them. 1 inserts x, 2 removes
// one x, 3 prints x's rank, 4 prints the x-th smallest key, 5 and 6 print the
// predecessor and successor of x. The +-1 offsets skip the -INF sentinel.
signed main() {
    srand((unsigned)time(0));
    build();
    int n;
    scanf("%lld", &n);
    for (; n; --n) {
        int op, x;
        scanf("%lld%lld", &op, &x);
        switch (op) {
        case 1:
            insert(rt, x);
            break;
        case 2:
            remove(rt, x);
            break;
        case 3:
            printf("%lld\n", getrk(rt, x) - 1);
            break;
        case 4:
            printf("%lld\n", getval(rt, x + 1));
            break;
        case 5:
            printf("%lld\n", getpre(x));
            break;
        case 6:
            printf("%lld\n", getnxt(x));
            break;
        default:
            break;
        }
    }
    return 0;
}
// Settings for the struct-wrapped Treap that follows: every `int` becomes
// long long, MN is the size of its node pool, and INF is the sentinel key.
#define int long long
#define MN 101000 
const int INF=(int)1<<60;
// Rotating Treap (multiset) stored in a fixed pool tr[MN]. Node 0 is the null
// node. build() places the -INF / INF sentinels in nodes 1 and 2, which the
// rank queries compensate for and getpre/getnxt use as fallback answers.
// At most MN-1 nodes may be allocated between build() calls; nnew does not check.
struct treap {
    struct node {
        int ls, rs;   // children as pool indices, 0 = none
        int val, dat; // key and random priority (max-heap via rotations)
        int cnt, siz; // multiplicity of val, total multiplicity of the subtree
        #define ls(p) tr[p].ls
        #define rs(p) tr[p].rs
        #define v(p) tr[p].val
        #define d(p) tr[p].dat
        #define c(p) tr[p].cnt
        #define s(p) tr[p].siz
    } tr[MN];
    int top = 0, rt = 0;  // last allocated pool slot, root index

    // Allocates slot top+1 as a leaf holding val with a random priority.
    // The slot is cleared first, so slots reused after build() carry no stale links.
    int nnew(int val) {
        tr[++top] = node{};
        v(top)=val;
        d(top)=rand();
        c(top)=s(top)=1;
        return top;
    }
    // Recomputes p's subtree size from its children and its own count.
    void pushup(int p) {
        s(p)=s(ls(p))+s(rs(p))+c(p);
    }
    // (Re)initialises an empty tree holding only the two sentinels.
    void build() {
        tr[0] = node{};  // null node must be all-zero (s(0) == 0)
        top = 0;         // restart allocation so the sentinels are nodes 1 and 2
        rt=nnew(-INF);
        rs(rt)=nnew(INF);
        pushup(rt);
    }
    // Returns 1 + (number of keys < val) in subtree p, counting multiplicities.
    // This is val's rank inside the subtree whether or not val is stored.
    // Called on rt, the count includes the -INF sentinel.
    int getrk(int p,int val) {
        if(p==0) return 1;  // val absent: its rank is one past all smaller keys
        if(val==v(p)) return s(ls(p))+1;
        if(val<v(p)) return getrk(ls(p),val);
        return getrk(rs(p),val)+s(ls(p))+c(p);
    }
    // rk-th smallest key (1-based, with multiplicities) in subtree p;
    // INF when rk is out of range.
    int getval(int p,int rk) {
        if(p==0) return INF;
        if(rk<=s(ls(p))) return getval(ls(p),rk);
        if(rk<=s(ls(p))+c(p)) return v(p);
        return getval(rs(p),rk-s(ls(p))-c(p));
    }
    // Right rotation: lifts p's left child into p's place.
    void uplt(int &p) {
        int q=ls(p);
        ls(p)=rs(q), rs(q)=p, p=q;
        pushup(rs(p)), pushup(p);
    }
    // Left rotation: lifts p's right child into p's place.
    void uprt(int &p) {
        int q=rs(p);
        rs(p)=ls(q), ls(q)=p, p=q;
        pushup(ls(p)), pushup(p);
    }
    // Inserts one occurrence of val below p. A new node is rotated up while
    // its priority exceeds its parent's.
    void insert(int &p,int val) {
        if(p==0) {
            p=nnew(val);
            return;
        }
        if(val==v(p)) {
            c(p)++, pushup(p);
            return;
        }
        if(val<v(p)) {
            insert(ls(p),val);
            if(d(p)<d(ls(p))) uplt(p);
        }
        else {
            insert(rs(p),val);
            if(d(p)<d(rs(p))) uprt(p);
        }
        pushup(p);
    }
    // Largest key strictly less than val. Falls back to node 1 (-INF).
    int getpre(int val) {
        int ans=1, p=rt;
        for( ; p; p=val<v(p)?ls(p):rs(p)) {
            if(val==v(p)) {
                // the answer is the maximum of the left subtree, if any
                if(ls(p)) {
                    for(p=ls(p); rs(p); p=rs(p));
                    ans=p;
                }
                break;
            }
            if(v(p)<val&&v(p)>v(ans)) ans=p;
        }
        return v(ans);
    }
    // Smallest key strictly greater than val. Falls back to node 2 (INF).
    int getnxt(int val) {
        int ans=2, p=rt;
        for( ; p; p=val<v(p)?ls(p):rs(p)) {
            if(v(p)==val) {
                // the answer is the minimum of the right subtree, if any
                if(rs(p)) {
                    for(p=rs(p); ls(p); p=ls(p));
                    ans=p;
                }
                break;
            }
            if(v(p)>val&&v(p)<v(ans)) ans=p;
        }
        return v(ans);
    }
    // Removes one occurrence of val (no-op if absent). A node whose count would
    // drop to zero is rotated down, lifting the child with the larger priority,
    // until it is a leaf. It is then unlinked.
    void remove(int &p,int val) {
        if(p==0) return;
        if(val==v(p)) {
            if(c(p)>1) {
                c(p)--, pushup(p);
                return;
            }
            if(ls(p)||rs(p)) {
                if(rs(p)==0||d(ls(p))>d(rs(p)))
                    uplt(p), remove(rs(p),val);
                else uprt(p), remove(ls(p),val);
                pushup(p);
            }
            else p=0;
            return;
        }
        val<v(p)?remove(ls(p),val):remove(rs(p),val);
        pushup(p);
    }
} ;
#include <cstdlib>
#include <cstdio>

// Number of operations left to process; read and counted down by main().
int n;

// Pointer-based Treap node. fix is a random priority; insert() and remove()
// keep smaller fix values nearer the root.
struct NODE {
    int value, fix;
    NODE *left, *right;

    NODE(const int value) : value(value), fix(rand()), left(nullptr), right(nullptr) {}
};

// Lifts p->left into p's position; the old p becomes its right child.
void rightRotate(NODE *&p) {
    NODE *const pivot = p->left;
    p->left = pivot->right;
    pivot->right = p;
    p = pivot;
}

// Lifts p->right into p's position; the old p becomes its left child.
void leftRotate(NODE *&p) {
    NODE *const pivot = p->right;
    p->right = pivot->left;
    pivot->left = p;
    p = pivot;
}

// Inserts value as a new node (duplicates allowed; equal keys descend left).
// On the way back up, a child whose fix is smaller than its parent's is
// rotated above it.
void insert(NODE *&p, const int value) {
    if (!p) {
        p = new NODE(value);
        return;
    }
    if (value <= p->value) {
        insert(p->left, value);
        if (p->left->fix < p->fix) rightRotate(p);
        return;
    }
    insert(p->right, value);
    if (p->right->fix < p->fix) leftRotate(p);
}

// Returns 1 if some node holds value, 0 otherwise.
// NOTE(review): despite the name this is membership, not multiplicity, even
// though insert() stores duplicates — confirm this is what callers expect.
int count(const NODE *p, const int value) {
    for (; p; p = (value <= p->value) ? p->left : p->right)
        if (p->value == value) return 1;
    return 0;
}

// Deletes one node holding value, if any. A matching node with two children is
// rotated down (the child with the smaller fix is lifted) until it has at most
// one child. It is then spliced out and freed.
void remove(NODE *&p, const int value) {
    if (!p) return;
    if (value != p->value) {
        remove(value < p->value ? p->left : p->right, value);
        return;
    }
    if (p->left && p->right) {
        if (p->left->fix < p->right->fix) {
            rightRotate(p);
            remove(p->right, value);
        } else {
            leftRotate(p);
            remove(p->left, value);
        }
        return;
    }
    NODE *victim = p;
    p = p->right ? p->right : p->left;
    delete victim;
}

NODE *root;  // tree root; null (empty tree) until the first insert

// Reads n, then n lines "opt x". 1 inserts x, 2 removes one x, and any other
// opt prints count(root, x).
int main() {
    // srand(...);

    scanf("%d", &n);
    while (n--) {
        int opt, x;
        scanf("%d %d", &opt, &x);

        switch (opt) {
        case 1:
            insert(root, x);
            break;
        case 2:
            remove(root, x);
            break;
        default:
            printf("%d\n", count(root, x));
            break;
        }
    }

    return 0;
}

非旋 Treap

#include <utility>
#include <cstdlib>
#include <cstdio>
using std::make_pair;
using std::pair;

// Node of the split/merge (non-rotating) Treap. fix is a random priority;
// merge() keeps the smaller fix nearer the root.
struct NODE {
    int value, fix;
    NODE *left, *right;

    NODE(const int value) : value(value), fix(rand()), left(nullptr), right(nullptr) {}
};

// Splits tree p by key without allocating. .first receives every node with
// value <= value and .second every node with a larger value.
pair<NODE*, NODE*> split(NODE *p, const int value) {
    if (!p) return pair<NODE*, NODE*>();  // both halves empty

    if (value < p->value) {
        // p and its right subtree belong to the upper half
        const pair<NODE*, NODE*> parts = split(p->left, value);
        p->left = parts.second;
        return pair<NODE*, NODE*>(parts.first, p);
    }
    // p and its left subtree belong to the lower half
    const pair<NODE*, NODE*> parts = split(p->right, value);
    p->right = parts.first;
    return pair<NODE*, NODE*>(p, parts.second);
}

// Joins two treaps into one and returns its root. Requires every key in l to
// be <= every key in r. The root with the smaller fix stays on top; ties go to r.
NODE* merge(NODE *l, NODE *r) {
    if (!l || !r) return l ? l : r;
    if (r->fix <= l->fix) {
        r->left = merge(l, r->left);
        return r;
    }
    l->right = merge(l->right, r);
    return l;
}

// Inserts one new node holding value: split at value, then merge
// (<= value) + new node + (> value) back together.
void insert(NODE *&p, const int value) {
    const pair<NODE*, NODE*> halves = split(p, value);
    NODE *const fresh = new NODE(value);
    p = merge(merge(halves.first, fresh), halves.second);
}

// Removes one node from subtree p: descends left until it reaches a node with
// at most one child, then splices that node out and frees it. Does nothing
// when p is NULL. That happens when remove() is asked for a key that is not
// stored: the split leaves the equal-key part empty.
void remove_one(NODE *&p) {
    if (p == NULL) return;  // nothing to remove (key absent)
    if (p->left == NULL || p->right == NULL) {
        NODE *t = p;
        if (p->right) p = p->right;
        else p = p->left;
        delete t;
    }
    else remove_one(p->left);
}

// Removes one occurrence of value. The tree is split into (< value), (== value)
// and (> value); one node is dropped from the middle part, then everything is
// merged back. NOTE(review): `value - 1` overflows when value == INT_MIN.
void remove(NODE *&p, const int value) {
    const pair<NODE*, NODE*> upper = split(p, value);               // (<= value, > value)
    const pair<NODE*, NODE*> lower = split(upper.first, value - 1); // (< value, == value)
    NODE *equal = lower.second;
    remove_one(equal);
    p = merge(merge(lower.first, equal), upper.second);
}

// Returns 1 if some node holds value, 0 otherwise.
// NOTE(review): despite the name this is membership, not multiplicity, even
// though insert() stores duplicates — confirm this is what callers expect.
int count(const NODE *p, const int value) {
    for (; p; p = (value <= p->value) ? p->left : p->right)
        if (p->value == value) return 1;
    return 0;
}

NODE *root;  // tree root; null (empty tree) until the first insert
int n;       // number of operations still to read

// Reads n, then n lines "opt x". 1 inserts x, 2 removes one x, and any other
// opt prints count(root, x).
int main() {
    // srand(...);

    scanf("%d", &n);
    while (n--) {
        int opt, x;
        scanf("%d %d", &opt, &x);

        switch (opt) {
        case 1:
            insert(root, x);
            break;
        case 2:
            remove(root, x);
            break;
        default:
            printf("%d\n", count(root, x));
            break;
        }
    }

    return 0;
}