SL（Skip List，跳表）学习笔记

· · 个人记录

#include<bits/stdc++.h>
using namespace std;
// Random engine used by rndlvl() to pick node heights (seeded from wall-clock time).
mt19937 rnd(time(nullptr));

const int N = 1e5 + 10;      // not referenced elsewhere in this file
const int M = 15;            // highest level a node can reach (levels are 1..M)
const int inf = 0x7f7f7f7f;  // sentinel: the head holds -inf; getnxt() returns inf when there is no successor
int n;                       // number of operations, read in main()

// One skip-list node. Levels are indexed 1..M; slot 0 and the +10 slack are unused.
//   val    : stored key
//   nxt[i] : next node on level i, or nullptr
//   dto[i] : number of level-1 steps from this node to nxt[i]
// The arrays are zero-initialized so that reading the span of a slot that was never
// linked (e.g. the head's spans before the first insert) is well-defined.
struct chain{
    int val = 0;
    int dto[M + 10] = {};
    chain* nxt[M + 10] = {};
};
typedef chain* cp;

cp head, upd[M + 10];        // sentinel head; upd[i] = predecessor on level i found by rnk()
int lv, sz, tot[M + 10];     // current top level, element count, per-level positions from rnk()

void init(){
    lv = 1,sz = 0;
    head = new(chain);
    head->val = -inf;
    for(int i = M;i >= 1;-- i){
        head->nxt[i] = nullptr;
        upd[i] = head;
    }
    return;
}

// Pick a random height for a new node. Start at level 1 and keep climbing while a
// fair coin (the low bit of rnd()) comes up 1, capping the height at M.
// The coin is tossed before the cap is checked, so one draw is consumed even at M.
int rndlvl(){
    int height = 1;
    while((rnd() & 1) && height < M)
        ++ height;
    return height;
}

int rnk(int x){
    cp p = head;
    tot[lv + 1] = 0;
    for(int i = lv;i >= 1;-- i){
        tot[i] = tot[i + 1];
        while(p->nxt[i] && p->nxt[i]->val < x)
            tot[i] += p->dto[i],p = p->nxt[i];
        upd[i] = p;
    }
    return tot[1] + 1;
}

void insert(int x){
    cp p = new(chain);
    p->val = x;
    rnk(x);
    int lay = rndlvl();
    lv = max(lay,lv);
    for(int i = 1;i <= lay;++ i){
        p->nxt[i] = upd[i]->nxt[i];
        p->dto[i] = upd[i]->dto[i] - (tot[1] - tot[i]);
        upd[i]->nxt[i] = p;
        upd[i]->dto[i] = tot[1] - tot[i] + 1;
    }
    for(int i = lay + 1;i <= lv;++ i) ++ upd[i]->dto[i];
    ++ sz;
    return;
}

// Remove one occurrence of value x (the first node holding x).
// If x is not stored, the list is left unchanged. Previously this deleted the
// successor node instead, or dereferenced null when x exceeded every element.
void remove(int x){
    rnk(x);                        // upd[i] = last node < x on each level
    cp p = upd[1]->nxt[1];         // first node whose value is >= x (or nullptr)
    if(p == nullptr || p->val != x) return;
    for(int i = 1;i <= lv;++ i){
        if(upd[i]->nxt[i] == p){
            // Unlink p on this level; the predecessor's span absorbs p's span minus p itself.
            upd[i]->dto[i] += p->dto[i] - 1;
            upd[i]->nxt[i] = p->nxt[i];
        }else if(upd[i]->nxt[i]){
            // The link skips over p on this level, so it now covers one element fewer.
            -- upd[i]->dto[i];
        }
    }
    // Drop empty top levels, but always keep at least level 1.
    while(lv > 1 && head->nxt[lv] == nullptr) -- lv;
    delete p;
    -- sz;
}

int kth(int k){
    cp p = head;
    for(int i = lv;i >= 1;-- i)
        while(p->nxt[i] && p->dto[i] < k)
            k -= p->dto[i],p = p->nxt[i];
    return p->nxt[1]->val;
}

int getpre(int x){
    cp p = head;
    for(int i = lv;i >= 1;-- i)
        while(p->nxt[i] && p->nxt[i]->val < x)
            p = p->nxt[i];
    if(p->val != inf) return p->val;
    return -inf;
}

// Return the smallest stored value strictly greater than x, or inf if there is none.
// The search stops on the last node whose value is <= x; its level-1 successor is
// the answer. The pointer is tested, not the value: the old `if(p->nxt[1]->val)`
// crashed when no successor existed and returned inf when the successor was 0.
int getnxt(int x){
    cp p = head;
    for(int i = lv;i >= 1;-- i)
        while(p->nxt[i] && p->nxt[i]->val <= x)
            p = p->nxt[i];
    if(p->nxt[1] != nullptr) return p->nxt[1]->val;
    return inf;
}
// Driver: read n, then n lines of "op x" and dispatch each one:
//   1 insert x, 2 remove x, 3 print rank of x, 4 print k-th smallest (k = x),
//   5 print predecessor of x, 6 print successor of x. Unknown ops are ignored.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    init();
    cin >> n;
    while(n -- != 0){
        int op,x;
        cin >> op >> x;
        switch(op){
            case 1: insert(x); break;
            case 2: remove(x); break;
            case 3: cout << rnk(x) << '\n'; break;
            case 4: cout << kth(x) << '\n'; break;
            case 5: cout << getpre(x) << '\n'; break;
            case 6: cout << getnxt(x) << '\n'; break;
            default: break;
        }
    }
    return 0;
}