学习笔记:平衡树

· · 个人记录

比较厉害的一种数据结构

二叉搜索树 BST

简介

二叉搜索树是一种二叉树的树形数据结构,其定义如下:

  1. 空树是二叉搜索树。
  2. 若二叉搜索树的左子树不为空,则其左子树上所有点的附加权值均小于其根节点的值。
  3. 若二叉搜索树的右子树不为空,则其右子树上所有点的附加权值均大于其根节点的值。
  4. 二叉搜索树的左右子树均为二叉搜索树。

摘自 oi-wiki 。

操作

二叉搜索树支持六个基本操作:

建树

并没有什么复杂操作,但为避免意外的越界情况,通常会先插入两个值分别为 -infinf 的节点。

插入

我们可以从根节点开始,用循环或者递归方式,对于新插入的值 v ,比较其与当前节点值 x 的大小关系,若 v < x 则尝试去其左子树,若 v > x 尝试去其右子树,若 v = x 则直接让当前节点的计数变量 cnt1 ,若根本不存在当前节点则直接创建权值为 v 的新节点。

删除

首先按照类似于插入操作的方式找到这个节点 p

按排名查值

也是类似于插入,但需要额外增添一个变量 size ,表示已 p 为根节点的子树中 cnt 之和。

假设当前要查找排名为 k 的数,即第 k 小的数。记 p_lp 的左子节点,p_rp 的右子节点。 具体而言,到达一个节点 p

按值查排名

类似于插入的检索方式,把途经的节点的左子树的 size 和途经节点的 cnt 加起来,直到找到符合的节点,(找到的这个节点不用加它的 cnt )并对累积的和加 1 ,就是排名。

查前驱/后继

查前驱和后继是类似的。以前驱为例。

小结

二叉搜索树是比较强大的数据结构,处理理想的随机数据时,每次操作时间复杂度是 O(\log n) 。但其缺陷在于,如果数据是人为制造而非随机生成,树可能退化为链,进而导致时间复杂度退化为 O(n)

平衡树

模板链接

【模板】普通平衡树

平衡树的概念

我们通过一些特殊处理,以保证二叉搜索树不会退化为链,并且结构“平衡”(使得操作复杂度可观)。这些处理与二叉搜索树的结合就称为平衡树

平衡树有很多种,常见的有 Treap, Splay, 红黑树,替罪羊树 等。C++中一些STL容器就是用红黑树实现的。

竞赛中常用的是 Treap 和 Splay,因为大多时候都够用,并且写起来没那么复杂。

旋转:平衡树的核心操作

事实上,一颗二叉搜索树的中序遍历,对应的就是维护的序列的不降序排列方式。而我们知道,同一中序遍可以对应不同形态的树,因此我们可以在不改变二叉搜索树中序遍历的情况下改变数的形态。

平衡树的核心操作,旋转,就是这样,可以在满足二叉搜索树性质的情况下,对树的结构做出改变。

旋转分为左旋 (zag) 和右旋 (zig) 。

其实很难用语言表述这一过程,所以就直接挂个图 上图表述已经很清晰了。

关于旋转,有一个口诀,方便记忆和写程序:

左旋拎右左挂右,右旋拎左右挂左

至于旋转不改变二叉搜索树的性质这一结论,这里不作证明因为压根不会

后面所涉及的代码中,都没有记录父亲节点的信息,因此我们调用旋转函数所传参数为父亲节点,即上图右旋 zig(y) 这种形式 。

另外需要注意的是,我们会改变这棵树的结构,但转移节点维护的信息略显麻烦,因此我们只改变节点左右儿子指针指向的节点编号。而为便于写程序简便,调用函数要传地址。(比如zig(p),那么调用完毕后,p 节点被旋转到下方,而其左儿子到达它原本位置,因此 p 的父亲的儿子指针应指向 p 的左儿子。若直接传地址,并且在函数结尾把 p 地址 赋值为其左儿子的地址,那么 p 的父亲的儿子指针自然指向了 p 的左儿子。)

我们按照口诀写出代码:

inline void zig(int &p) {
    int q = t[p].l;
    t[p].l = t[q].r, t[q].r = p, p = q;//注意p为引用
    updatet(t[p].r), updatet(p);
}
inline void zag(int &p) {
    int q = t[p].r;
    t[p].r = t[q].l, t[q].l = p, p = q;//注意p为引用 
    updatet(t[p].l), updatet(p);
}

但是,怎样的旋转才能让树维持在平衡状态却是问题。接下来讨论的几种平衡树各有各的方法。

Treap

简介

Treap 就是把二叉搜索树 (Tree) 和堆 (Heap) 结合(就连单词本身亦是如此)。每个节点有两个权值, 一个是真实的权值,一个是程序随机生成权值。以真实权值为关键字,排列方式满足二叉搜索树的性质;以随机权值为关键字,排列方式满足堆的性质。也就是说,维护的节点将同时满足堆和二叉搜索树的性质和结构。

由于堆必定不会退化为链,因此 Treap 不会退化为链。那么可以保证操作时间复杂度维持在 O(\log n)

所以,只要树不满足堆的性质,我们就不停旋转直到它满足,以保证树的平衡。

操作

插入

插入,Treap需要先按二叉搜索树的方式插入,并生成一个随机权值。完成后不停把这个更新的点向上旋转,直到满足堆的性质。这与堆的 up 操作是类似的,只不过每次不是纯粹交换两个点,而是旋转。

删除

删除,类似堆的想法,我们把要删除的点向下旋转到叶子,然后直接删除。

其它操作

其它操作类似于二叉搜索树。

源码

/**********************************

Problem: luogu - P3369 - 【模板】普通平衡树 (treap)

State: Accepted

From: 【模板】 

Algorithm: treap 

Author: cyh_toby

Last updated on 2020.7.19

**********************************/
#include <cstdio>
#include <cstdlib>
#include <algorithm>

using namespace std;

const int N = 1e5+5, inf = 1e8;

struct Treap{
    int l, r;
    int val, dat;//真实权, 随机权 
    int cnt, siz;
} t[N];
int n, rt, tot;

inline int createt(int x) {
    t[++tot].val = x, t[tot].dat = rand(), t[tot].cnt = t[tot].siz = 1;
    return tot;
}
inline void updatet(int p) {
    t[p].siz = t[p].cnt + t[t[p].l].siz + t[t[p].r].siz;
}
int rnk(int p, int v) {
    if (p == 0) return 0;
    if (v == t[p].val) return t[t[p].l].siz+1;
    if (v < t[p].val) return rnk(t[p].l, v);
    return rnk(t[p].r, v) + t[t[p].l].siz + t[p].cnt;
}
int kth(int p, int k) {
    if (p == 0) return inf;
    if (t[t[p].l].siz >= k) return kth(t[p].l, k);
    if (t[t[p].l].siz + t[p].cnt >= k) return t[p].val;
    return kth(t[p].r, k - t[t[p].l].siz - t[p].cnt);
}
inline void zig(int &p) {
    int q = t[p].l;
    t[p].l = t[q].r, t[q].r = p, p = q;
    updatet(t[p].r), updatet(p);
}
inline void zag(int &p) {
    int q = t[p].r;
    t[p].r = t[q].l, t[q].l = p, p = q;
    updatet(t[p].l), updatet(p);
}
inline void buildt() {
    createt(-inf), createt(inf);
    rt = 1, t[rt].r = 2;
    if (t[1].dat < t[2].dat)
        zag(rt);
}
void ins(int &p, int v) {
    if (p == 0) {
        p = createt(v);
        return;
    }
    if (v == t[p].val) {
        t[p].cnt++, updatet(p);
        return;
    }
    if (v < t[p].val) {
        ins(t[p].l, v);
        if (t[t[p].l].dat > t[p].dat) zig(p);
    }
    if (v > t[p].val) {
        ins(t[p].r, v);
        if (t[t[p].r].dat > t[p].dat) zag(p);
    }
    updatet(p);
}
//若要用kth和rnk完成pre和nxt, 需要修改kth和rnk函数细节 
int pre(int p, int v) {
    if (!p)
        return -inf;
    if (v <= t[p].val)
        return pre(t[p].l, v);
    int res = pre(t[p].r, v);
    return res == -inf? t[p].val : res;
}
int nxt(int p, int v) {
    if (!p)
        return inf;
    if (v >= t[p].val)
        return nxt(t[p].r, v);
    int res = nxt(t[p].l, v);
    return res == inf? t[p].val : res;
}
void del(int &p, int v) {
    if (p == 0) return;
    if (v == t[p].val) {
        if (t[p].cnt > 1) {
            t[p].cnt--, updatet(p);
            return;
        }
        if (t[p].l || t[p].r) {//非叶子 
            if (t[p].r == 0 || t[t[p].l].dat > t[t[p].r].dat)
                zig(p), del(t[p].r, v);
            else
                zag(p), del(t[p].l, v);
            updatet(p);
        }
        else p = 0;//叶子
        return;
    }
    if (v < t[p].val)
        del(t[p].l, v);
    else
        del(t[p].r, v);
    updatet(p);
}

int main()
{
    buildt();
    scanf("%d", &n);
    while (n--) {
        int op, x;
        scanf("%d%d", &op, &x);
        switch (op) {
            case 1:
                ins(rt, x);
                break;
            case 2:
                del(rt, x);
                break;
            case 3:
                printf("%d\n", rnk(rt, x)-1);
                break;
            case 4:
                printf("%d\n", kth(rt, x+1));
                break;
            case 5:
                printf("%d\n", pre(rt, x));
                break;
            case 6:
                printf("%d\n", nxt(rt, x));
                break;
        }
    }
    return 0;
}

Splay

简介

Splay 的核心操作是 Splay 。感觉这句话可以无限循环下去。

Splay 由丹尼尔·斯立特Daniel Sleator 和 罗伯特·恩卓·塔扬Robert Endre Tarjan 在1985年发明的。又是伟大的Tarjan!

一次 Splay 操作,其实就是两次旋转,称之为双旋。但这两次旋转在不同情况下,顺序、种类都是不同的。在 Splay 所规定的双旋操作下,可以尽可能维持树的平衡。Splay 通过把每次操作的点,按照它所规定的旋转方式旋转到根节点,以维持平衡。其中,旋转至根节点的旋转原则是:能双旋就双旋,否则单旋。

双旋:Splay的核心操作

具体而言,对于三个节点,其排列方式可能是共线也可能非共线。

对于共线的,我们要先将父亲向上旋转,再将要旋转的节点向上旋转。

对于非共线的,我们直接把要旋转的节点向上旋转两次即可。

至于为什么双旋可以使树平衡,这里不作论述。还是因为不会。

上图中的 zig() 和 zag() 都是上文 treap 中的函数,所传参数为父亲节点,原因是没有记录父亲节点。

但在下面的 Splay 代码中,记录了节点父亲节点的编号。这个目的是简化代码,并且在如下实现方式中可以把左旋和右旋合并为一个函数 rot() ,而这个函数所传参数为所要旋转上升的点编号。

操作

删除

首先把要删除的点伸展到根节点。如果其个数标记大于 1 ,对个数进行修改即可,不必继续进行下面的操作。

如果没有子节点,直接把这个点删除。

但是,理论上,如果我们进行了类似于 treap 中的 build 操作,插入了两个无穷节点,上面这个步骤可以省略。

如果它没有左子树或者右子树,直接让唯一的子节点成为根节点。

否则,即左右子节点都存在,那么可以找到它的前驱,把前驱 splay 到根节点,随后修改相关指针。 值得注意的是,前驱 splay 到根节点后,其右儿子必定为我们要删除的节点,并且我们要删除的节点必定没有左儿子。这一性质有助于简化代码。

此外,如果我们先写好了前驱函数 pre() ,其结尾必定会把找到的节点 splay 到根节点。那么,我们在删除操作的时候直接调用 pre() 函数就可以直接把后继 splay 到根节点了。这也有助于简化代码。

其它操作

没有什么特别之处。但在下面的源码中,pre 函数和 nxt 函数的实现细节是值得注意的。

需要牢记的是,所有操作结束后,均需把操作的点 splay 到根节点。

源码

无注释版本在此 。

/**********************************

Problem: luogu - P3369 - 【模板】普通平衡树 (splay)

State: Uncompleted

From: 【模板】 

Algorithm: splay 

Author: cyh_toby

Last updated on 2020.7.20

**********************************/
#include <cstdio>

using namespace std;

const int N = 1e5+5, inf = 1e8;

struct Splay {
    int rt, tot, f[N], ch[N][2], val[N], cnt[N], siz[N];
    #define upd(p) siz[p] = cnt[p] + siz[ch[p][0]] + siz[ch[p][1]]
    #define get(p) (p == ch[f[p]][1]) //判断是左儿子还是右儿子 
    #define crt(x) val[++tot] = x, cnt[tot] = siz[tot] = 1 //create建立新节点 
    inline void build() {
        crt(-inf), crt(inf);
        rt = 1, ch[rt][1] = 2, f[2] = rt;
        upd(rt);
    }
    inline void rot(int p) {
        int x = f[p], y = f[x], u = get(p), v = get(x);
        f[ch[p][u^1]] = x, ch[x][u] = ch[p][u^1];
        f[x] = p, ch[p][u^1] = x, upd(x), upd(p);
        f[p] = y;
        if (y) ch[y][v] = p;
    }
    inline void splay(int p, int g) {//把p伸展到g的子节点 
        while (f[p] != g) {
            int x = f[p], y = f[x];
            if (y == g) rot(p);//祖父是目标, 单旋 
            else rot(get(p) == get(x) ? x : p), rot(p);
        }
        if (!g) rt = p;
    }
    inline void ins(int v) {
        int x = rt, y = 0;
        while (1) {
            if (val[x] == v) {
                cnt[x]++, siz[x]++;
                upd(y), splay(x, 0);
                break;
            }
            y = x, x = ch[y][val[y] < v];
            if (!x) {
                crt(v);
                f[tot] = y, ch[y][val[y] < v] = tot, upd(y);
                splay(tot, 0);
                break;
            }
        }
    }
    inline int rnk(int v) {//最小的大于等于v的数的排名 
        int p = rt, res = 0; 
        while (1) { 
            if (v < val[p]) {
                if (!ch[p][0]) {
                    res++;
                    break;
                }
                p = ch[p][0];
            }
            else if (v == val[p]){
                res += siz[ch[p][0]] + 1;
                break;
            }
            else {
                res += siz[ch[p][0]] + cnt[p];
                if (!ch[p][1]) {
                    res++;
                    break;
                }
                p = ch[p][1];
            }
        }
        splay(p, 0);
        return res;
    }
    inline int kth(int k) {//查找排名为k的数 
        int p = rt;
        while (1) {
            if (k <= siz[ch[p][0]])
                p = ch[p][0];
            else if (k <= siz[ch[p][0]] + cnt[p])
                break;
            else
                k -= siz[ch[p][0]] + cnt[p], p = ch[p][1];
        }
        splay(p, 0);
        return val[p];
    }
    inline int pre(int x) {//注意, 根据rnk函数的定义, nxt与pre实现略有不同 
        return kth(rnk(x) - 1);
    }
    inline int nxt(int x) {
        return kth(rnk(x + 1));
    }
    inline void del(int v) {
        rnk(v);//借用rnk函数, 把第一个大于等于v的数splay到根节点
        if (v != val[rt]) return;//不存在v
        if (cnt[rt] > 1) {
            cnt[rt]--, siz[rt]--;
            return;
        }
        //不用判没有子节点, 因为buildt插入了两个无穷节点
        if (!ch[rt][0] || !ch[rt][1]) {
            rt = ch[rt][0] + ch[rt][1];
            f[rt] = 0;
            return;
        }
        int p = rt;
        pre(v);//找到v的前驱, splay到root,, 此时要删除的p必定是现在的根的右儿子, 且必定没有左儿子 
        f[ch[p][1]] = rt, ch[rt][1] = ch[p][1];
        siz[rt]--;
        return;
    }
} splay;

int main()
{
    splay.build();
    int n;
    scanf("%d", &n);
    while (n--) {
        int op, x;
        scanf("%d%d", &op, &x);
        switch (op) {
            case 1: splay.ins(x); break;
            case 2: splay.del(x); break;
            case 3: printf("%d\n", splay.rnk(x)-1); break;
            case 4: printf("%d\n", splay.kth(x+1)); break;
            case 5: printf("%d\n", splay.pre(x)); break;
            case 6: printf("%d\n", splay.nxt(x)); break;
        }
    }
    return 0;
} 

参考资料