Splay 的简介和相关题目

· · 个人记录

Splay 的错误记录和模板

https://shiroi-he.gitee.io/blog/2020/01/03/Splay%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/

叶笑天的模板

lhm_ 的模板

一种常用的平衡树,也是LCT的必备工具

很久以前就学了平衡树(treap),后来不常用,给忘了。后来听学长说Splay更有用,然后学了Splay,然后除了LCT用了一下,也没怎么用,也忘了。于是现在我把Splay捡一捡。

该拔的毒瘤终将要拔!

警告:

1.准备好各种调题工具和方法

2.为保证正确性,该有的特判一定要有.

3.不要忘记pushup和pushdown

4.细心细心再细心!

此外,还有:

以下两点至少要保证一点:

1.不用0号节点来更改其他节点信息

2.不让0号节点的信息被更改

知识精要

Splay本质是一棵二叉搜索树,其性质是某关键字的中序序列是单调递增的每个节点的关键字权值应该大于其左子树节点,小于其右子树节点

Splay 的使用

Splay 的常见应用为:维护一个集合(不常用,通常可以用set配合权值树状数组来代替),维护一个序列(常用,支持插入删除,区间翻转,区间加, 区间推平等等),以及维护一条链(详见LCT)。

Splay 通常有两种用法:

  1. 节点编号无意义,二叉树的中序序列以节点权值为关键字排序。一般用于维护集合

  2. 节点编号无意义,二叉树的中序序序列以节点在序列中的位置为关键字排序。一般用于维护序列。

  3. 节点编号为给定编号,二叉树的中序序列以给定编号所代表的东西的位置(如深度) 为关键字排序。一般用于 Splay 的灵活考察题,以及LCT(还没有学,学完后再回来改)

主要函数:

其它常用函数如下:

其余函数要依据题意来写。

技能展示

例题Ⅰ:P3369 【模板】普通平衡树

以值为关键字,维护一个 集合

当需要支持查排名和查第k大的时候,就不得不使用Splay,否则可以用 set 代替(常数也会小一些)

特殊的函数:

Code:

int n;
int rt, tot;
int son[N][2], fa[N], siz[N], cnt[N], val[N];
inline void pushup(int cur) {
    siz[cur] = cnt[cur];
    if (son[cur][0])    siz[cur] += siz[son[cur][0]];
    if (son[cur][1])    siz[cur] += siz[son[cur][1]];
}
inline bool get_which(int cur) {
    return son[fa[cur]][1] == cur;
}
inline void rotate(int cur) {
    int faa = fa[cur], fafa = fa[faa];
    int flag = get_which(cur);
    if (fafa)   son[fafa][get_which(faa)] = cur; fa[cur] = fafa;
    son[faa][flag] = son[cur][flag ^ 1]; if (son[cur][flag ^ 1])    fa[son[cur][flag ^ 1]] = faa;
    son[cur][flag ^ 1] = faa; fa[faa] = cur;
    pushup(faa);
}
inline void splay(int cur, int goal) {
    for (register int faa = fa[cur]; faa != goal; rotate(cur), faa = fa[cur]) 
        if (fa[faa] != goal)    rotate(get_which(cur) == get_which(faa) ? faa : cur);
    pushup(cur);
    if (!goal)  rt = cur;
}

inline void ins(int x) {
    int pre = 0, np = rt;
    while (np && val[np] != x)  pre = np, np = son[np][x > val[np]];
    if (np && val[np] == x) return cnt[np]++, siz[np]++, splay(np, 0), void();
    np = ++tot; siz[np] = cnt[np] = 1; val[np] = x;
    fa[np] = pre; if (pre)  son[pre][x > val[pre]] = np;
    son[np][0] = son[np][1] = 0;
    splay(np, 0);
}

inline int k_th(int k) {
    int np = rt;
    while (1) {
        if (son[np][0] && k <= siz[son[np][0]]) { np = son[np][0]; continue; }
        int tmp = siz[son[np][0]] + cnt[np];
        if (k <= tmp)   return val[np];
        k -= tmp;
        np = son[np][1];
    }
}

inline void find(int x) {
    int np = rt;
    while (val[np] != x && son[np][x > val[np]])    np = son[np][x > val[np]];
    splay(np, 0);
}

inline int get_rk(int x) {
    find(x); return siz[son[rt][0]] + 1;
}

inline int get(int x, int type) {
    find(x);
    if (val[rt] < x && type == 0)   return rt;
    if (val[rt] > x && type == 1)   return rt;
    int np = son[rt][type];
    while (np && son[np][type ^ 1]) np = son[np][type ^ 1];
    return np;
}

inline void del(int x) {
    int pre = get(x, 0), nxt = get(x, 1);
    splay(pre, 0); splay(nxt, pre);
    int np = son[nxt][0];
    if (cnt[np] > 1) {
        cnt[np]--; siz[np]--;
        splay(np, 0);
        return ;
    }
    son[nxt][0] = 0;
    pushup(nxt); pushup(rt);
}

int main() {
    read(n);
    rt = tot = 1; cnt[1] = siz[1] = 1, val[1] = inf;
    ins(-inf);
    int opt, aa;
    for (register int i = 1; i <= n; ++i) {
        read(opt); read(aa);
        if (opt == 1) {
            ins(aa);
        } else if (opt == 2) {
            del(aa);
        } else if (opt == 3) {
            printf("%d\n", get_rk(aa) - 1);
        } else if (opt == 4) {
            printf("%d\n", k_th(aa + 1));
        } else if (opt == 5) {
            printf("%d\n", val[get(aa, 0)]);
        } else if (opt == 6) {
            printf("%d\n", val[get(aa, 1)]);
        }
    }
    return 0;
}

习题:

线段树套平衡树,变量名重复比较严重,注意区分。

my record

线段树套平衡树。这次我使用namespace,成功解决变量名重复问题。本题数据较水,暴力(O(n^2))可过(毕竟我O(n^2logn)都能得到95分)

my record

例题Ⅱ:P3391 区间翻转

以数组下标为关键字,故寻找代表某数组的节点时要 find

值得注意的是,不能想当然地去 pushdown,要在 splay() 中先找出从 cur 到根路径的所有节点,自上而下依次 pushdown

引入新的常用函数: buildprint,这两个函数都是利用 SplayBST 性质,搞二叉树的前序遍历。还是很好理解的。

细心!!

Code:
int n, m, ttot, root;
int key[N], num[N], son[N][2], fa[N], siz[N];
bool tag[N];
inline bool get_which(int cur) {
    return son[fa[cur]][1] == cur;
}
inline void pushup(int cur) {
    siz[cur] = 1;
    if (son[cur][0])    siz[cur] += siz[son[cur][0]];
    if (son[cur][1])    siz[cur] += siz[son[cur][1]];
}
inline void pushdown(int cur) {
    if (!tag[cur])  return ;
    tag[son[cur][0]] ^= 1; tag[son[cur][1]] ^= 1;
    swap(son[cur][0], son[cur][1]);
    tag[cur] = 0;
}
inline void rotate(int cur) {
    int faa = fa[cur], fafa = fa[faa];
    bool flag = get_which(cur);
    fa[cur] = fafa; if (fafa)   son[fafa][get_which(faa)] = cur;
    son[faa][flag] = son[cur][flag ^ 1]; if (son[cur][flag ^ 1])    fa[son[cur][flag ^ 1]] = faa;
    son[cur][flag ^ 1] = faa; fa[faa] = cur;
    pushup(faa); pushup(cur);
}
int stk[N], stop;
inline void splay(int cur, int goal) {
    int p = cur; 
    while (p != root)   stk[++stop] = p, p = fa[p];
    pushdown(root);
    while (stop)    pushdown(stk[stop--]);

    for (register int faa = fa[cur]; faa != goal; rotate(cur), faa = fa[cur]) {
        if (fa[faa] != goal)    rotate(get_which(cur) == get_which(faa) ? faa : cur);
    }
    pushup(cur);
    if (goal == 0)  root = cur;
}
inline int find(int k) {
    int p = root;
    while (1) {
        pushdown(p);
        if (k <= siz[son[p][0]]) {
            p = son[p][0]; continue;
        }
        int tmp = siz[son[p][0]] + 1;
        if (k <= tmp)   break;
        k -= tmp;
        p = son[p][1];
    }
    return p;
}
inline void split(int x, int y) {
    splay(x, 0); splay(y, root);
}
int build(int L, int R, int faa) {
    if (L > R)  return 0;
    int mid = (L + R) >> 1;
    int cur = ++ttot;
    siz[cur] = 1;
    fa[cur] = faa;
    key[cur] = num[mid];
    son[cur][0] = build(L, mid - 1, cur);
    son[cur][1] = build(mid + 1, R, cur);
    pushup(cur);
    return cur;
}
inline void reverse(int l, int r) {
    r += 2; l = find(l); r = find(r);
    split(l, r);
    tag[son[son[root][1]][0]] ^= 1;
}
inline void print(int cur) {
    pushdown(cur);
    if (son[cur][0])    print(son[cur][0]);
    if (key[cur] >= 1 && key[cur] <= n) printf("%d ", key[cur]);
    if (son[cur][1])    print(son[cur][1]);
}
int main() {
    read(n); read(m); num[1] = -inf; 
    for (register int i = 1; i <= n; ++i) num[i + 1] = i;
    num[n + 2] = inf; root = build(1, n + 2, 0);
    register int aa, bb;
    for (register int i = 1; i <= m; ++i) {
        read(aa); read(bb);
        reverse(aa, bb);
    }
    print(root);
    puts("");
    return 0;
}

例题Ⅲ:P2042 [NOI2005]维护数列

(重题:P2710 数列)

以数组下标为关键字,维护一个 序列 (能完成类似线段树的功能,且支持插入删除翻转)

需要回收内存,即允许用已经del过的点:

int sta[N], top;
inline int nwnode() {
    int cur = top ? sta[top--] : ++tot;
    siz[cur] = val[cur] = sum[cur] = lmx[cur] = rmx[cur] = mx[cur] = 0;
    rev[cur] = set[cur] = son[cur][0] = son[cur][1] = fa[cur] = 0;
    return cur;
}

inline void del(int cur) {
    if (!cur)   return ;
    sta[++top] = cur;
    del(son[cur][0]); del(son[cur][1]);
}
...
cur = nwnode();
...

注意:pushup, pushdown!!

其余部分代码如下:


inline void pushup(int cur) {
    int ls = son[cur][0], rs = son[cur][1];
    siz[cur] = 1; sum[cur] = val[cur];
    siz[cur] += siz[ls], sum[cur] += sum[ls]; 
    siz[cur] += siz[rs], sum[cur] += sum[rs];
    lmx[cur] = max(lmx[ls], sum[ls] + val[cur] + lmx[rs]);
    rmx[cur] = max(rmx[rs], sum[rs] + val[cur] + rmx[ls]);
    mx[cur] = max(val[cur] + rmx[ls] + lmx[rs], max(mx[ls], mx[rs]));
}

inline void push_rev(int cur) {
    rev[cur] ^= 1;
    swap(son[cur][0], son[cur][1]);
    swap(lmx[cur], rmx[cur]);
}

inline void push_set(int cur, int vall) {
    if (!cur)   return ;
    set[cur] = true;
    val[cur] = vall; sum[cur] = vall * siz[cur];
    if (vall <= 0)  lmx[cur] = rmx[cur] = 0, mx[cur] = vall;
    else    lmx[cur] = rmx[cur] = mx[cur] = sum[cur];
}   

inline void pushdown(int cur) {
    if (set[cur])
        push_set(son[cur][0], val[cur]), push_set(son[cur][1], val[cur]), set[cur] = 0;
    if (rev[cur])
        push_rev(son[cur][0]), push_rev(son[cur][1]), rev[cur] = 0;
}

int list[N], tp;
inline void splay(int cur, int goal) {

    tp = 0;
    for (register int np = cur; np != goal; np = fa[np])    list[++tp] = np;
    if (goal)   list[++tp] = goal;
    while (tp)  pushdown(list[tp--]);       //attention!!

    for (register int faa = fa[cur]; faa != goal; rotate(cur), faa = fa[cur])
        if (fa[faa] != goal)    rotate(get_which(cur) == get_which(faa) ? faa : cur);
    pushup(cur);        //attention!!
    if (!goal) rt = cur;
}

void build(int L, int R, int &cur) {
    cur = nwnode();
    int mid = (L + R) >> 1;
    lmx[cur] = rmx[cur] = max(a[mid], 0);
    val[cur] = sum[cur] = mx[cur] = a[mid];
    if (L < mid)    build(L, mid - 1, son[cur][0]);
    if (R > mid)    build(mid + 1, R, son[cur][1]);
    fa[son[cur][0]] = fa[son[cur][1]] = cur;
    pushup(cur);        //attention!!
}

inline int kth(int k) {
    int np = rt;
    while (1) {
        pushdown(np);       //attention!!
        if (son[np][0] && k <= siz[son[np][0]]) { np = son[np][0]; continue; }
        int tmp = son[np][0] ? siz[son[np][0]] + 1 : 1;
        if (k <= tmp)   return np;
        k -= tmp;
        np = son[np][1];
    }
}

inline void split(int l, int r) {
    l = kth(l - 1);
    r = kth(r + 1);
    splay(l, 0);
    splay(r, l);
}

int main() {
    read(n); read(m);
    mx[0] = -inf;
    a[1] = a[n + 2] = -inf;
    for (register int i = 1; i <= n; ++i)   read(a[i + 1]);
    build(1, n + 2, rt);
    char opt[15];
    register int aa, tt, num;
    while (m--) {
        scanf("%s", opt);
        if (opt[0] == 'I') {//Insert
            read(aa); read(tt); aa++;
            for (register int i = 1; i <= tt; ++i)  read(a[i]);
            register int rtt; build(1, tt, rtt);
            split(aa + 1, aa);//attention!!!
            register int rs = son[rt][1];
            fa[rtt] = rs; son[rs][0] = rtt;
            pushup(rs); pushup(rt);     //attention!!
        } else if (opt[0] == 'D') {//Delete
            read(aa); read(tt); aa++;
            if (!tt)    continue;
            split(aa, aa + tt - 1);
            register int rs = son[rt][1];
            del(son[rs][0]);
            fa[son[rs][0]] = 0; son[rs][0] = 0;
            pushup(rs); pushup(rt);     //attention!!
        } else if (opt[0] == 'M' && opt[2] == 'K') {//Make same
            read(aa); read(tt); read(num); aa++;
            if (!tt)    continue;
            split(aa, aa + tt - 1);
            register int rs = son[rt][1];
            push_set(son[rs][0], num);
            pushup(rs); pushup(rt);     //attention!!
        } else if (opt[0] == 'R') {//reverse
            read(aa); read(tt); aa++;
            if (!tt)    continue;
            split(aa, aa + tt - 1);
            register int rs = son[rt][1];
            push_rev(son[rs][0]);
            pushup(rs); pushup(rt);     //attention!!
        } else if (opt[0] == 'G') {//get sum
            read(aa); read(tt); aa++;
            if (!tt) { puts("0"); continue; }
            split(aa, aa + tt - 1);
            register int rs = son[rt][1];
            printf("%d\n", sum[son[rs][0]]);
        } else if (opt[0] == 'M' && opt[2] == 'X') {//get max sum
            printf("%d\n", mx[rt]);
        }
    }
    return 0;
}

例题Ⅳ:P2596 [ZJOI2006]书架

题意:

维护一个编号序列(即由一堆带编号的点组成的序列),支持:将编号cur 的点置顶(放到最前面)/置底(放到最后面);将编号cur 的点与其 前驱/后继 交换位置;查询某编号位置;查询某位置编号

题解:

以编号为节点编号(方便查询),以序列中的位置为关键字,建立 Splay 维护。

考察灵活运用 Splay 的能力。

特殊函数:

inline void Cut(int cur) {//with his father
    int faa = fa[cur];
    son[faa][get_which(cur)] = 0; fa[cur] = 0;
    siz[faa] -= siz[cur];
}
inline void Swap(int x, int y) {
    splay(x, 0); splay(y, x);
    int rt1 = son[x][0], rt2 = son[y][1];
    son[x][0] = 0, son[y][1] = x;
    son[x][1] = rt2, son[y][0] = rt1;
    fa[x] = y, fa[y] = 0;
    fa[rt1] = y, fa[rt2] = x;
    pushup(x); pushup(y);
    root = y;
}
...
//(in main)
if (opt[0] == 'T') {//Top
    register int cur;
    read(cur);
    int pos = get_pos(cur);//including splay(cur, 0)
    if (pos == 0)   continue;
    int rt1 = son[cur][0], rt2 = son[cur][1];
    Cut(rt1); Cut(rt2);
    root = rt1;
    int tmp = find(1);
    splay(tmp, 0);
    son[n + 1][1] = cur; fa[cur] = n + 1;
    pushup(n + 1); pushup(root);
    tmp = find(pos + 1);
    splay(tmp, 0);
    son[tmp][1] = rt2; fa[rt2] = tmp;
    pushup(tmp);
} else {//Bottom
    register int cur;
    read(cur);
    int pos = get_pos(cur);
    if (pos == n - 1)   continue;
    int tmp = find(n);
    splay(cur, 0);
    int rt1 = son[cur][0], rt2 = son[cur][1];
    Cut(rt1); Cut(rt2);
    root = rt2;
    splay(tmp, 0);
    son[n + 2][0] = cur; fa[cur] = n + 2;
    pushup(n + 2); pushup(root);
    tmp = find(0);
    splay(tmp, 0);
    son[tmp][0] = rt1; fa[rt1] = tmp;
    pushup(tmp);
}

习题

注意!