数据结构-伸展树(Splay)

· · 算法·理论

伸展树,英文名叫 Splay Tree,是著名的 Tarjan 先生发明的一个神奇的数据结构。其本质是一棵二叉查找树,可以实现平衡树的所有操作,而且它的所有操作都基于伸展操作(将某个节点调整到根)。

另外提一嘴,这玩意叫 Splay,并不叫 Spaly。

基本操作

平衡树所有的操作 Splay 都可以做,而且在树上启发式合并的时候 Treap 甚至有时间复杂度上的优势。

旋转操作

与 treap 的旋转操作类似,这里将左旋和右旋合二为一了。先上一张经典的图:

我们把节点 x 的父亲称作 y,把 y 的父亲(如果有的话)称作 z。首先我们判断下 xy 的左儿子还是右儿子,记为 k,若 k = 0 说明是左儿子, k = 1 说明是右儿子。先把 x 偏向 y 那一侧的儿子跟 y 相连,接着把 xy 相连,最后把 xz 相连(如果有的话)。不要忘了更新旋转后的 siz 等信息,更新时必须按照父子关系更新

代码:

int get(int x){return ch[fa[x]][1] == x;}
void rotate(int x) {
    int y = fa[x], z = fa[y] , k = get(x);//判断 
    ch[y][k] = ch[x][k ^ 1];
    if(ch[x][k ^ 1]) fa[ch[x][k ^ 1]] = y;// x 的某个子节点 & y 
    ch[x][k ^ 1] = y;
    fa[y] = x;// x & y 
    if(z) ch[z][ch[z][1] == y] = x;
    fa[x] = z;//x & z
    push_up(y);push_up(x);
}

Splay操作

伸展树最核心的操作,其目的就是把节点 x 调整都根节点。将 x 的父节点称作 pp 的父亲称作 g。如何操作主要取决于 3 个因素: p 是不是根节点,xp 的左儿子还是右儿子, pg 的左儿子还是右儿子。

如果 p 是根节点,那么旋一次即可,一步到位,非常简单。

接下来假设 p 不是根节点。有四种不同的情况:xp 左儿子且 pg 左儿子,xp 右儿子且 pg 右儿子,xp 右儿子且 pg 左儿子,xp 左儿子且 pg 右儿子。由于 1,2 两种情况类似, 3,4 两种情况类似,只分析第 1 种和第 3 种。

zig-zig / zag-zag :如图所示,对应第 1 种情况(第 2 种情况也是相同处理的方式,因为我们已经把左旋右旋合二为一了)。此时我们先将 g 旋转一次,再将 x 旋转一次,就能得到图右边的新形态。这里注意旋转时的顺序,如果顺序不同,会导致得到的树的形态不同,最终无法进行势能分析。

zig-zag/zag-zig:如图所示,对应第 3 种情况。此时我们是将 x 连续旋转两次,就能把它旋转到根。

这样子,我们每次可以将 x 的深度减少 2(最后一次可能是 1),最后总能将 x 调整到根节点。

代码:

void splay(int x) {
    while(fa[x]){
        if(fa[fa[x]]) rotate(get(x) == get(fa[x]) ? fa[x] : x);
        rotate(x);
    }
    rt = x; 
}

合并操作

这里的合并操作,要求伸展树 x 中每个元素都小于伸展树 y 中每个元素,与后文的启发式合并作区分。如果 x 树或者 y 树为空,那么直接返回另一棵树的根节点即可。否则将 x 中最大的元素 mx 旋到根,此时 mx 的右子树为空,直接把 y 接上就好了记得更新节点信息

删除操作

首先在原树中找到 x 这个数,把它旋到根,如果值为 x 的数不止一个,直接减掉一个就行;如果只有一个,那么合并它的左右儿子。

插入操作

在原树上找这个值,找得到就把这个值的次数加一,更新节点信息;找不到就新建一个节点。最后别忘了把这个点旋到根。

查询操作

x 的排名,就相当于找小于等于 x - 1 的数有多少个,最后加一就是结果。

求排名为 x 的数,不断判断以目前节点为根的子树大小与剩余排名的关系,若小于,去左子树找;否则如果左子树的大小加上这个值的次数小于等于剩余排名,返回这个数;前两者都不满足,就去右子树找。

求某个数的前驱,就先在树中找到这个数,接着找到它左子树里最靠右边的节点。后继同理,在右子树里找最左边的节点。

当然,找到某个点后还是要通过把这个点旋到根节点的方式,来保证均摊的时间复杂度。

总结

其实还是挺好写的,但是在访问一个点后容易忘掉把它旋到根节点,修改一个点的值后容易忘记更新节点信息。

当然像刘昊,花花这种对璇那么上心的人肯定是不会忘的啦

时间复杂度分析

这一部分其实才是伸展树算法的难点(但貌似不会也没有关系)。它采用了一个叫“势能分析” 的神奇分析方法,但众所周知,信息是一个看结果的学科,谁爱分析谁分析去吧。

定理:假设最多有 n 个点,在进行 m 次 Splay 操作后,总复杂度是 O(m \log n + n \log n)

证明谁爱写谁写,反正不会没关系。有了这个定理,我们就可以证明对于所有操作,均摊都是 O(\log n) 的。

放一张 oi-wiki 的图上来,感兴趣的读者可以自行阅读。

#include <bits/stdc++.h>
using namespace std; 
const int inf = 1e9 + 100;
const int N = 1e5 + 100;
int ch[N][2], fa[N];  
int val[N], siz[N], cnt[N], idx;
int rt;
int New(int x){
    val[++idx] = x;
    siz[idx] = cnt[idx] = 1;
    return idx;
}
int get(int x){return ch[fa[x]][1] == x;}
void push_up(int x){siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];}
void rotate(int x) {
    int y = fa[x], z = fa[y] , k = get(x);//判断 
    ch[y][k] = ch[x][k ^ 1];
    if(ch[x][k ^ 1]) fa[ch[x][k ^ 1]] = y;// x 的某个子节点 & y 
    ch[x][k ^ 1] = y;
    fa[y] = x;// x & y 
    if(z) ch[z][ch[z][1] == y] = x;
    fa[x] = z;//x & z
    push_up(y);push_up(x);
}
void splay(int x) {
    while(fa[x]){
        if(fa[fa[x]]) rotate(get(x) == get(fa[x]) ? fa[x] : x);
        rotate(x);
    }
    rt = x; 
    fa[x] = 0;
    push_up(x);
}
void init(){
    New(-inf), New(inf);
    rt = 1, ch[1][1] = 2, siz[1] = 2, fa[2] = 1;
}
int get_rank(int u, int k,int f){
    if(!u){
        splay(f);
        return 0;
    }
    if(k == val[u]){
        int res = siz[ch[u][0]] + cnt[u];
        splay(u);
        return res;
    }
    if(k < val[u]) return get_rank(ch[u][0], k, u);
    int res = siz[ch[u][0]] + cnt[u];
    return get_rank(ch[u][1], k, u) + res;
}
int get_num(int u, int k, int f){
    if(!u){
        splay(f);
        return inf; 
    }
    if(siz[ch[u][0]] >= k) return get_num(ch[u][0], k, u);
    if(siz[ch[u][0]] + cnt[u] >= k){
        splay(u);
        return val[u];
    }
    return get_num(ch[u][1], k - siz[ch[u][0]] - cnt[u], u);
}
void ins(int u, int k , int f){
    if(!u){
        u = New(k);
        fa[u] = f;
        ch[f][k > val[f]] = u;  
        push_up(u);
        push_up(f);
        splay(u);
        return ;
    }
    if(val[u] == k){
        cnt[u]++;
        push_up(u);
        push_up(f); 
        splay(u);
        return ;
    }
    ins(ch[u][k > val[u]], k, u);
}
int merge(int x, int y){
    if(!x){
        fa[y] = 0;
        return y;
    }
    if(!y){
        fa[x] = 0;
        return x;
    }
    int t = x;
    while(ch[t][1]) t = ch[t][1];
    splay(t); 
    ch[t][1] = y;fa[t] = 0;
    fa[y] = t;
    return t;
}
void clear(int u){
    fa[ch[u][0]] = fa[ch[u][1]] = 0;
    cnt[u] = siz[u] = val[u] = ch[u][0] = ch[u][1] = 0;
}
void del(int u, int k, int f){
    if(!u){
        splay(f);
        return ;    
    }
    if(val[u] == k){
        splay(u);
        if(cnt[u] > 1) {
            cnt[u]--;
            push_up(u);
            return ;
        }   
        fa[ch[u][0]] = fa[ch[u][1]] = 0;
        rt = merge(ch[u][0], ch[u][1]);
        fa[u] = cnt[u] = siz[u] = val[u] = ch[u][0] = ch[u][1] = 0;
        return ;
    }
    del(ch[u][k > val[u]], k, u);
    push_up(u);
}
int get_pre(int k){
    int ans = 1, u = rt;
    while(u){
        if(val[u] == k){
            if(ch[u][0]){
                u = ch[u][0];
                while(ch[u][1]) u = ch[u][1];
                ans = u;
            }
            break;
        }   
        if(val[u] < k && val[u] > val[ans]) ans = u;
        u = ch[u][k > val[u]];
    }
    splay(ans);
    return val[ans];
}
int get_nxt(int k){
    int ans = 2, u = rt;
    while(u){
        if(val[u] == k){
            if(ch[u][1]){
                u = ch[u][1];
                while(ch[u][0]) u = ch[u][0];
                ans = u;
            }
            break;
        }   
        if(val[u] > k && val[u] < val[ans]) ans = u;
        u = ch[u][k > val[u]];
    }
    splay(ans);
    return val[ans];
}
int main(){ 
    int n;cin >> n;
    init(); 
    for(int i = 1;i <= n;i++) {
        int opt, x;
        cin >> opt >> x;
        if(opt == 1) ins(rt, x, 0);
        if(opt == 2) del(rt, x, 0);
        if(opt == 3) cout << get_rank(rt, x - 1, 0) << endl;
        if(opt == 4) cout << get_num(rt, x + 1, 0) << endl; 
        if(opt == 5) cout << get_pre(x) << endl;
        if(opt == 6) cout << get_nxt(x) << endl;    
    }
    return 0;
}

树上启发式合并

(定理 2 肯定是对的,但笔者显然不想,也不会去证明)

问题:对于 x 个 满足二叉查找树性质的树,假如它们一共有 n 个节点,我们每次任意选择两棵树,将小的那棵树当中每一个节点往大的那棵树合并,直到合并到 1 棵树。

定理1:正常平衡树可以在 O(n \log^2n) 的复杂度内解决此问题。

证明1:每次合并整体分析比较难搞,不如考虑每个元素对于总时间复杂度的贡献。当 x 所在的树需要和一棵更大的 Splay 树合并的时候,才会需要不超过 O(\log n) 的时间复杂度把它插入到新的树里。而合并后 x 所在树的大小至少翻倍,所以对于每个数,它至多插了 O(\log n) 次。一共有 n 个点,所以时间复杂度为 O(n\log^2n)

定理2:对于一个 n 个点的 Splay 树,有 m 次插入/删除/查找操作,那么可以在 O(n + m + \sum \log(d_i + 1)) 的时间复杂度内完成这 m 次操作,其中 d_i 表示第 i 个操作的元素与第 i - 1 个元素在操作时排名之差。

这个我不会证。

定理2的引理: 对于两棵大小分别为 nm 的 Splay 树(m \le n),合并它们的复杂度是 O(m (\log n -\log m ))

定理3 :Splay 可以在 O(n \log n) 的时间复杂度内完成上述操作。

证明3:我们将 O(n + m + \sum \log(d_i + 1)) 化为 O(n + m) + O(\sum \log(d_i + 1))。对于前半段,这是启发式合并的模板,每个点可以均摊 O(1) 的进行插入(放在一起分析),而每个节点所在的子树往上合并时大小翻倍,所以每个点最多被合并 O(\log n) 次,总共 n 个点,复杂度是 O(n\log n);考虑每棵树对于时间复杂度的贡献。与证明 1 相同,假设这 x 个 Splay 树的大小分别是 a_1,a_2\dots a_x,记 s_i = \sum_{j = 1}^i a_i,那么对于某个元素,其对总复杂度的贡献不会超过 O( \log\dfrac{s_2}{s_1} + \log\dfrac{s_3}{s_2} + \dots + \log\dfrac{s_x}{s_{x - 1}}) = O(\log s_x) < O(\log n),那么最多 n 个节点,不会超过 O(n\log n)

所以总复杂度是 O(n\log n)

总而言之,当我们需要不断合并两个平衡树的时候,用伸展树来实现可以让时间复杂度更优。