【学习笔记】Splay

· · 算法·理论

\texttt{0x01} 前言

Splay 树(伸展树)由 Daniel SleatorRobert Tarjan 于 1985 年发明。它凭借旋转可以有 O(\log n) 插入,删除等的较优秀的时间复杂度。

前置芝士:普通二叉排序树。

推荐博客:

\texttt{0x02} 如何构造一棵 Splay

我们定义一个结构体:

#define val(x) t[x].val
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define son(x,nxt) t[x].ch[nxt]
#define fa(x) t[x].fa
#define cnt(x) t[x].cnt
#define siz(x) t[x].siz
struct node{
    int val,fa,ch[2],siz,cnt;
}t[N];
int root,tot;

其中构造一个新节点的函数长这样:

void newPoint(int val,int fa,int nxt){ //值为val,父节点为fa,为fa的nxt儿子
    tot++;
    fa(tot)=fa; cnt(tot)=siz(tot)=1; val(tot)=val;
    son(fa,nxt)=tot;
}

\texttt{0x03} which / pushup / connect

which 的作用是判断 x 是其父节点的左节点(0)还是右节点(1),代码很好写:

bool which(int x){
    return rs(fa(x))==x;
}

pushup 的作用是维护当前节点的 siz 信息,和线段树的 pushup 性质差不多,代码:

void pushup(int x){
    siz(x)=siz(ls(x))+siz(rs(x))+cnt(x);//记得加上当前节点的cnt
}

connect 的作用是把 x 变成 ynxt 儿子,无需考虑覆盖的问题,代码也很简洁:

void connect(int x,int y,int nxt){
    son(y,nxt)=x;
    fa(x)=y;
}

\texttt{0x04} rotate

Splay 的核心操作:旋转。

放两张动图:

我们会发现:右旋时,E 节点要到 S 节点的位置上,那么 E 节点的右儿子因为它 \ge E\le S,所以只能放在 S 节点的左儿子,然后要改变 E 和 S 的父子关系。最后别忘了因为有旋转,所以要自下而上更新节点信息。

左旋同理。

代码:

void rotate(int x){
    int y=fa(x),z=fa(y);
    int fx=which(x),fy=which(y);

    connect(son(x,fx^1),y,fx); //如果x是左儿子,改变它右儿子的位置,反之同理
    connect(y,x,fx^1); //把y接到x的缺失的那一棵子树上
    connect(x,z,fy); //把x接到y的父节点上去

    pushup(y); pushup(x); //别搞错顺序
}

\texttt{0x05} splay

Splay 树保证时间复杂度正确的核心操作,把 x 转到 y 的位置(y 通常为 root)。

有几点结论,难证但好记:

void splay(int x,int y){
    y=fa(y); //避免x=y时出现的错误
    while(fa(x)!=y){
        if(fa(fa(x))==y) // Case 1
            rotate(x);
        else if(which(x)==which(fa(x))) // Case 2
            rotate(fa(x)), rotate(x);
        else // Case 3
            rotate(x), rotate(x);
    }
    if(y==0){ // 如果y是根,把根变为x
        root=x;
        connect(x,0,1);
    }
}

\texttt{0x06} insert

与普通的二叉排序树基本一致。

记得最后要 splay 一下,把这个点转到根节点。

void insert(int val){
    if(root==0){
        newPoint(val,0,1);
        root=tot;
        return;
    }
    int now=root;
    while(1){
        siz(now)++;
        if(val(now)==val){
            cnt(now)++;
            splay(now,root);
            return;
        }
        int nxt=val(now)<val, son=son(now,nxt);
        if(!son){
            newPoint(val,now,nxt);
            splay(tot,root);
            return;
        }
        now=son;
    }
}

\texttt{0x07} find

这一步操作是找到树中值为 val 的节点,并把它旋转到根节点,为 delete 操作做准备。

与普通二叉排序树也基本一致。

int find(int val){
    int now=root;
    while(1){
        if(!now)
            return 0;
        if(val(now)==val){
            splay(now,root);
            return now;
        }
        int nxt=val(now)<val, son=son(now,nxt);
        now=son;
    }
}

\texttt{0x08} delete

目的是删除树中值为 val 的节点。

先 find 这个节点,让他转到根,然后分类讨论。

  1. 树中没有值为 val 的节点,删了个寂寞。

  2. 树中值为 val 的节点有不止一个(即 cnt \ge 2),让 cnt \gets cnt-1 即可。

  3. 这个节点没有左儿子(即根节点只有右子树),把右儿子设为根就行了。

  4. 这个节点没有右儿子(即根节点只有左子树),把左儿子设为根就行了。

  5. 这个节点(设为 x)有左右儿子,把它的左子树中值最大的(设为 y)splay 到根,然后现在的 Splay 树的根就是 y,左子树是原来的除 y 之外的左子树,右子树是 x 和之前的右子树。把之前的右子树 connect 到根就行了。

注意:这里的删除操作都没有回收编号

void delet(int val){
    int now=find(val);
    if(!now) return;
    if(cnt(now)>1){
        cnt(now)--; siz(now)--;
        return;
    }
    if(!ls(now) && !rs(now)){
        root=0;
    }
    else if(!ls(now)){
        root=rs(root);
        fa(root)=0;
    }
    else if(!rs(now)){
        root=ls(root);
        fa(root)=0;
    }
    else{
        int pos=ls(now);
        while(rs(pos)) pos=rs(pos);

        splay(pos,root);
        connect(rs(now),pos,1);
        pushup(pos);
    }
}

\texttt{0x09} rnk & find_k

rnk 是返回值为 val 的数在树中的排名,find_k 是找到树中排名为 k 的数。

与二叉排序树基本完全相同。记得最后要把节点 splay 到根。

int rnk(int val){
    int now=root,s=0;
    while(now){
        if(val(now)==val){
            splay(now,root);
            return siz(ls(now))+1;
        }
        if(val(now)<val){
            s+=siz(ls(now))+cnt(now);
            now=rs(now);
        }
        else{
            now=ls(now);
        }
    }
    return s+1;
}
int find_k(int k){
    int now=root;
    while(1){
        int used=siz(now)-siz(rs(now));
        if(k>siz(ls(now)) && k<=used){
            break;
        }
        if(k>=used){
            k-=used;
            now=rs(now);
        }
        else{
            now=ls(now);
        }
    }
    splay(now,root);
    return val(now);
}

\texttt{0x0A} lower & upper

返回值为 val 的数的前驱和后继。

int lower(int val){
    int ans=-2147483647;
    int now=root;
    while(now){
        if(val(now)<val && val(now)>ans){
            ans=val(now);
        }
        if(val>val(now)){
            now=rs(now);
        }
        else{
            now=ls(now);
        }
    }
    return ans;
}
int upper(int val){
    int ans=2147483647;
    int now=root;
    while(now){
        if(val(now)>val && val(now)<ans){
            ans=val(now);
        }
        if(val<val(now)){
            now=ls(now);
        }
        else{
            now=rs(now);
        }
    }
    return ans;
}

\texttt{0x0B} 完整代码

#include<bits/stdc++.h>
using namespace std;
constexpr int N = 5e5+5;

template <typename T> void read(T &x){x=0; T f(0); char ch=getchar(); while(ch<'0'||ch>'9'){f|=ch=='-';ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=getchar();} x=f?-x:x;}
template <typename T,typename ...Arg>void read(T& x,Arg& ...arg){read(x);read(arg...);}
template <typename T> inline void write(T x){static char buf[64]; static int tot(0); if(x<0) putchar('-'),x=-x; do buf[++tot]=(x%10)+48,x/=10; while(x); do putchar(buf[tot--]); while(tot);}
template <typename T> void write(T x,char c){static char buf[64]; static int tot(0); if(x<0) putchar('-'),x=-x; do buf[++tot]=(x%10)+48,x/=10; while(x); do putchar(buf[tot--]); while(tot); putchar(c);}

class Splay{
#define val(x) t[x].val
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define son(x,nxt) t[x].ch[nxt]
#define fa(x) t[x].fa
#define cnt(x) t[x].cnt
#define siz(x) t[x].siz
private:
    struct node{
        int val,fa,ch[2],siz,cnt;
    }t[N];
    int root,tot;
public:
    bool which(int x){
        return rs(fa(x))==x;
    }
    void pushup(int x){
        siz(x)=siz(ls(x))+siz(rs(x))+cnt(x);
    }
    void connect(int x,int y,int nxt){
        son(y,nxt)=x;
        fa(x)=y;
    }
    void rotate(int x){
        int y=fa(x),z=fa(y);
        int fx=which(x),fy=which(y);

        connect(son(x,fx^1),y,fx);
        connect(y,x,fx^1);
        connect(x,z,fy);

        pushup(y); pushup(x);
    }
    void splay(int x,int y){
        y=fa(y);
        while(fa(x)!=y){
            if(fa(fa(x))==y)
                rotate(x);
            else if(which(x)==which(fa(x)))
                rotate(fa(x)), rotate(x);
            else
                rotate(x), rotate(x);
        }
        if(y==0){
            root=x;
            connect(x,0,1);
        }
    }
    void newPoint(int val,int fa,int nxt){
        tot++;
        fa(tot)=fa; cnt(tot)=siz(tot)=1; val(tot)=val;
        son(fa,nxt)=tot;
    }
    void insert(int val){
        if(root==0){
            newPoint(val,0,1);
            root=tot;
            return;
        }
        int now=root;
        while(1){
            siz(now)++;
            if(val(now)==val){
                cnt(now)++;
                splay(now,root);
                return;
            }
            int nxt=val(now)<val, son=son(now,nxt);
            if(!son){
                newPoint(val,now,nxt);
                splay(tot,root);
                return;
            }
            now=son;
        }
    }
    int find(int val){
        int now=root;
        while(1){
            if(!now)
                return 0;
            if(val(now)==val){
                splay(now,root);
                return now;
            }
            int nxt=val(now)<val, son=son(now,nxt);
            now=son;
        }
    }
    void delet(int val){
        int now=find(val);
        if(!now) return;
        if(cnt(now)>1){
            cnt(now)--; siz(now)--;
            return;
        }
        if(!ls(now) && !rs(now)){
            root=0;
        }
        else if(!ls(now)){
            root=rs(root);
            fa(root)=0;
        }
        else if(!rs(now)){
            root=ls(root);
            fa(root)=0;
        }
        else{
            int pos=ls(now);
            while(rs(pos)) pos=rs(pos);

            splay(pos,root);
            connect(rs(now),pos,1);
            pushup(pos);
        }
    }
    int rnk(int val){
        int now=root,s=0;
        while(now){
            if(val(now)==val){
                splay(now,root);
                return siz(ls(now))+1;
            }
            if(val(now)<val){
                s+=siz(ls(now))+cnt(now);
                now=rs(now);
            }
            else{
                now=ls(now);
            }
        }
        return s+1;
    }
    int find_k(int k){
        int now=root;
        while(1){
            int used=siz(now)-siz(rs(now));
            if(k>siz(ls(now)) && k<=used){
                break;
            }
            if(k>=used){
                k-=used;
                now=rs(now);
            }
            else{
                now=ls(now);
            }
        }
        splay(now,root);
        return val(now);
    }
    int lower(int val){
        int ans=-2147483647;
        int now=root;
        while(now){
            if(val(now)<val && val(now)>ans){
                ans=val(now);
            }
            if(val>val(now)){
                now=rs(now);
            }
            else{
                now=ls(now);
            }
        }
        return ans;
    }
    int upper(int val){
        int ans=2147483647;
        int now=root;
        while(now){
            if(val(now)>val && val(now)<ans){
                ans=val(now);
            }
            if(val<val(now)){
                now=ls(now);
            }
            else{
                now=rs(now);
            }
        }
        return ans;
    }
}tr;

int n,opt,val;

int main(){
    read(n);
    while(n--){
        read(opt,val);
        int ans;
        switch(opt){
            case 1:{
                tr.insert(val);
                break;
            }
            case 2:{
                tr.delet(val);
                break;
            }
            case 3:{
                ans=tr.rnk(val);
                break;
            }
            case 4:{
                ans=tr.find_k(val);
                break;
            }
            case 5:{
                ans=tr.lower(val);
                break;
            }
            case 6:{
                ans=tr.upper(val);
                break;
            }
        }
        if(opt>2) write(ans,'\n');
    }
}