Splay树(伸展树)实现详解

· · 算法·理论

Splay树(伸展树)实现详解

1. 数据结构定义

const int N=1e5;
int rt;//根节点
int val[N],cnt[N],siz[N];//数据域:值、计数、子树大小
int fa[N],son[N][2];//指针域:父节点、左右儿子

实现原理

2. 辅助函数

2.1 方向判断函数

//判断节点x的方位,返回0为左,1为右
int dir(int x){
    return son[fa[x]][1]==x;
}

实现原理

2.2 更新函数

//更新节点x的大小
void up(int x){
    siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}

实现原理

3. 核心操作

3.1 旋转操作

//旋转操作
void rotate(int x){
    int y=fa[x],z=fa[y],s=dir(x);//y是x的父节点,z是y的父节点,s是x的方向

    //处理x的与旋转方向相反的子节点
    if(son[x][!s]){
        fa[son[x][!s]]=y;//将x的反向子节点连接到y上
    }
    son[y][s]=son[x][!s];//y继承x的反向子节点

    //处理x与y的关系
    son[x][!s]=y;//x成为y的父节点
    fa[y]=x;//y的父节点指向x

    //处理x与z的关系
    fa[x]=z;//x的父节点指向z
    if(z){
        son[z][son[z][1]==y]=x;//z的对应子节点指向x
    }
    up(y);//先更新y,因为y在x的下层
    up(x);//再更新x
}

实现原理

4. 伸展操作

4.1 伸展到根节点

//伸展操作:将节点x伸展到根节点
void splay_small(int x){
    while(fa[x]){//一直旋转直到x成为根节点
        int y=fa[x];//x的父节点
        if(fa[y]){//如果y不是根节点,考虑双旋
            if(dir(x)==dir(y)){
                rotate(y);//一字型情况,先旋转y
            }else{
                rotate(x);//之字形情况,先旋转x
            }
        }
        rotate(x);//最后旋转x
    }
    rt=x;//更新根节点
}

4.2 伸展到指定位置

//伸展操作:将节点x伸展到p的位置
void splay(int x,int &p){
    int z=fa[p];//目标位置的父节点
    while(fa[x]!=z){//一直旋转直到x的父节点是z
        int y=fa[x];//x的父节点
        if(fa[y]!=z){//如果y的父节点不是z,考虑双旋
            if(dir(x)==dir(y)){
                rotate(y);//一字型情况,先旋转y
            }else{
                rotate(x);//之字形情况,先旋转x
            }
        }
        rotate(x);//最后旋转x
    }
    p=x;//更新p指向x
}

实现原理

5. 基本BST操作

5.1 插入操作

//插入操作
void insert(int v){
    static int idx=0;//静态变量,节点计数器
    int x=rt,y=0;//x从根开始搜索,y记录父节点
    while(x && val[x]!=v){//查找插入位置
        x=son[y=x][v>val[x]];//根据大小关系选择左右子树
    }
    if(x){//如果节点已存在
        cnt[x]++;//增加计数
        siz[x]++;//更新大小
    }else{//如果节点不存在,创建新节点
        x=++idx;//分配新节点编号
        val[x]=v;//设置节点值
        cnt[x]=siz[x]=1;//初始化计数和大小
        fa[x]=y;//设置父节点
        if(y){//如果父节点存在
            son[y][v>val[y]]=x;//将新节点连接到父节点的对应位置
        }
    }
    splay(x,rt);//将新插入的节点伸展到根位置
}

实现原理

  1. 查找位置:从根开始,根据BST性质找到插入位置

  2. 处理重复:如果值已存在,增加计数

  3. 创建节点:如果值不存在,创建新节点并连接到树中

  4. 伸展优化:将新节点伸展到根,提高后续访问效率

5.2 查找操作

//查找值为v的节点并伸展到根
void find(int v){
    int x=rt;
    while(x && val[x]!=v){
        x=son[x][v>val[x]];
    }
    if(x) splay(x,rt);
}

实现原理

6. 查询操作

6.1 排名查询

//查询值v的排名(比v小的数的个数+1)
int rnk(int v){
    find(v);//先查找v并伸展到根
    if(val[rt]>=v){//如果根节点值>=v,排名在左子树中
        return siz[son[rt][0]]+1;
    }else{//如果根节点值<v,排名在右子树中
        return siz[son[rt][0]]+cnt[rt]+1;
    }
}

实现原理

6.2 第k小查询

//查询第k小的值
int kth(int k){
    int x=rt;
    while(x){
        if(k<=siz[son[x][0]]){//k在左子树中
            x=son[x][0];
        }else if(k<=siz[son[x][0]]+cnt[x]){//k在当前节点中
            splay(x,rt);//将找到的节点伸展到根
            return val[x];
        }else{//k在右子树中
            k-=siz[son[x][0]]+cnt[x];
            x=son[x][1];
        }
    }
    return -1;//未找到
}

实现原理

7. 前驱后继操作

7.1 前驱查询

//查找前驱(小于v的最大值)
int pre(int v){
    find(v);//先查找v并伸展到根
    if(val[rt]<v) return val[rt];//如果根节点值小于v,直接返回
    int x=son[rt][0];//否则在左子树中找最大值
    if(!x) return -1;//不存在前驱
    while(son[x][1]) x=son[x][1];//一直往右走
    splay(x,rt);//将前驱伸展到根
    return val[x];
}

7.2 后继查询

//查找后继(大于v的最小值)
int nxt(int v){
    find(v);//先查找v并伸展到根
    if(val[rt]>v) return val[rt];//如果根节点值大于v,直接返回
    int x=son[rt][1];//否则在右子树中找最小值
    if(!x) return -1;//不存在后继
    while(son[x][0]) x=son[x][0];//一直往左走
    splay(x,rt);//将后继伸展到根
    return val[x];
}

实现原理

8. 删除操作

//删除值为v的节点
void del(int v){
    find(v);//先查找v并伸展到根
    if(val[rt]!=v) return;//不存在该值

    if(cnt[rt]>1){//如果有多个相同值
        cnt[rt]--;
        siz[rt]--;
        return;
    }

    //只有一个节点的情况
    if(!son[rt][0] && !son[rt][1]){//没有子节点
        rt=0;//树为空
    }else if(!son[rt][0]){//只有右子树
        fa[son[rt][1]]=0;
        rt=son[rt][1];
    }else if(!son[rt][1]){//只有左子树
        fa[son[rt][0]]=0;
        rt=son[rt][0];
    }else{//有两个子节点
        int x=son[rt][0];
        while(son[x][1]) x=son[x][1];//在左子树中找最大值
        splay(x,son[rt][0]);//将x伸展到左子树的根

        //连接右子树
        son[x][1]=son[rt][1];
        fa[son[rt][1]]=x;

        //更新根
        fa[x]=0;
        rt=x;
        up(rt);//更新根节点大小
    }
}

实现原理

9. 测试示例

int main(){
    // 示例操作
    cout << "Splay树操作示例:" << endl;

    // 插入操作
    insert(5);
    insert(3);
    insert(7);
    insert(1);
    insert(9);
    insert(4);
    insert(6);

    cout << "插入 5, 3, 7, 1, 9, 4, 6 后:" << endl;

    // 查询排名
    cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4
    cout << "数字1的排名: " << rnk(1) << endl; // 应该输出1

    // 查询第k小
    cout << "第3小的数: " << kth(3) << endl; // 应该输出4
    cout << "第5小的数: " << kth(5) << endl; // 应该输出6

    // 查询前驱和后继
    cout << "数字4的前驱: " << pre(4) << endl; // 应该输出3
    cout << "数字4的后继: " << nxt(4) << endl; // 应该输出5
    cout << "数字7的前驱: " << pre(7) << endl; // 应该输出6
    cout << "数字7的后继: " << nxt(7) << endl; // 应该输出9

    // 删除操作
    cout << "\\\\n删除数字4后:" << endl;
    del(4);

    cout << "数字4的排名: " << rnk(4) << endl; // 应该输出4(因为4被删除了,现在排名4的是5)
    cout << "第3小的数: " << kth(3) << endl;   // 应该输出5

    // 再次插入测试重复值
    cout << "\\\\n再次插入数字5(重复值):" << endl;
    insert(5);

    cout << "数字5的排名: " << rnk(5) << endl; // 应该输出4或5,取决于实现

    // 边界测试
    cout << "\\\\n边界测试:" << endl;
    cout << "最小值: " << kth(1) << endl;      // 应该输出1
    cout << "最大值: " << kth(6) << endl;      // 应该输出9(因为现在有6个节点)
    cout << "数字0的前驱: " << pre(0) << endl; // 应该输出-1(不存在)
    cout << "数字10的后继: " << nxt(10) << endl; // 应该输出-1(不存在)

    return 0;
}

10.无注释版本

#include<bits/stdc++.h>
using namespace std;

const int N=1e5;
int rt;
int val[N],cnt[N],siz[N];
int fa[N],son[N][2];

int dir(int x){
    return son[fa[x]][1]==x;
}

void up(int x){
    siz[x]=siz[son[x][0]]+siz[son[x][1]]+cnt[x];
}

void rotate(int x){
    int y=fa[x],z=fa[y],s=dir(x);
    if(son[x][!s]){
        fa[son[x][!s]]=y;
    }
    son[y][s]=son[x][!s];
    son[x][!s]=y;
    fa[y]=x;
    fa[x]=z;
    if(z){
        son[z][son[z][1]==y]=x;
    }
    up(y);
    up(x);
}

void splay_small(int x){
    while(fa[x]){
        int y=fa[x];
        if(fa[y]){
            if(dir(x)==dir(y)){
                rotate(y);
            }else{
                rotate(x);
            }
        }
        rotate(x);
    }
    rt=x;
}

void splay(int x,int &p){
    int z=fa[p];
    while(fa[x]!=z){
        int y=fa[x];
        if(fa[y]!=z){
            if(dir(x)==dir(y)){
                rotate(y);
            }else{
                rotate(x);
            }
        }
        rotate(x);
    }
    p=x;
}

void insert(int v){
    static int idx=0;
    int x=rt,y=0;
    while(x && val[x]!=v){
        x=son[y=x][v>val[x]];
    }
    if(x){
        cnt[x]++;
        siz[x]++;
    }else{
        x=++idx;
        val[x]=v;
        cnt[x]=siz[x]=1;
        fa[x]=y;
        if(y){
            son[y][v>val[y]]=x;
        }
    }
    splay(x,rt);
}

void find(int v){
    int x=rt;
    while(x && val[x]!=v){
        x=son[x][v>val[x]];
    }
    if(x) splay(x,rt);
}

int rnk(int v){
    find(v);
    if(val[rt]>=v){
        return siz[son[rt][0]]+1;
    }else{
        return siz[son[rt][0]]+cnt[rt]+1;
    }
}

int kth(int k){
    int x=rt;
    while(x){
        if(k<=siz[son[x][0]]){
            x=son[x][0];
        }else if(k<=siz[son[x][0]]+cnt[x]){
            splay(x,rt);
            return val[x];
        }else{
            k-=siz[son[x][0]]+cnt[x];
            x=son[x][1];
        }
    }
    return -1;
}

int pre(int v){
    find(v);
    if(val[rt]<v) return val[rt];
    int x=son[rt][0];
    if(!x) return -1;
    while(son[x][1]) x=son[x][1];
    splay(x,rt);
    return val[x];
}

int nxt(int v){
    find(v);
    if(val[rt]>v) return val[rt];
    int x=son[rt][1];
    if(!x) return -1;
    while(son[x][0]) x=son[x][0];
    splay(x,rt);
    return val[x];
}

void del(int v){
    find(v);
    if(val[rt]!=v) return;

    if(cnt[rt]>1){
        cnt[rt]--;
        siz[rt]--;
        return;
    }

    if(!son[rt][0] && !son[rt][1]){
        rt=0;
    }else if(!son[rt][0]){
        fa[son[rt][1]]=0;
        rt=son[rt][1];
    }else if(!son[rt][1]){
        fa[son[rt][0]]=0;
        rt=son[rt][0];
    }else{
        int x=son[rt][0];
        while(son[x][1]) x=son[x][1];
        splay(x,son[rt][0]);

        son[x][1]=son[rt][1];
        fa[son[rt][1]]=x;

        fa[x]=0;
        rt=x;
        up(rt);
    }
}

int main(){
    return 0;
}