题解 P3369 【【模板】普通平衡树(Treap/SBT)】

huangwenlong

2017-12-03 19:22:35

Solution

7个月前写过一篇题解,今天回来看下结果自己都看不下去,于是就来重写了。 # 递归版Splay 优点:**不用维护父指针!!!**刚开始学写非递归Splay的时候被父指针的维护坑了好久!!! 参考了大刘的《训练指南》。 ## 实现 **前排警告:前方存在大量结构体+指针** 首先我用名为node的结构体保存节点: ```cpp const int inf = 0x7fffffff; struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车 struct node { node *ch[2]; // ch[0]是左儿子指针,ch[1]是右儿子指针 int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数 int cmp(int v) // 如代码所示,返回要寻找值为v的元素该向左走还是向右走 { if (v == val) return -1; else return v < val ? 0 : 1; } int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走 { if (k <= ch[0]->size) return 0; else if (k <= ch[0]->size + cnt) return -1; else return 1; } void pullup() { size = cnt + ch[0]->size + ch[1]->size; } // 用于插入或删除后重新计算size node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } // 普通的构造函数 } * root; void init() // 主要用来初始化哨兵节点 { nil = new node(0); root = nil->ch[0] = nil->ch[1] = nil; nil->size = nil->cnt = 0; } ``` 下面的说明中node既可以表示**节点**也可以表示**树**。 用0/1表示向左/向右,用`ch[0]`/`ch[1]`表示左儿子/右儿子指针,用`cmp(int v)`返回往左走/往右走,用异或运算取相反方向,这些都是来自大刘《训练指南》的技巧。因为平衡树中对称的情形太多了,合理运用这些技巧可以压缩代码量。 ### 伸展 所谓递归Splay,其实就是把寻找节点和伸展节点写在了一起,把递归寻找节点展开一层以后塞几行代码调用旋转。如果找不到这个值的节点,就会伸展最后一个访问到的节点。 ```cpp void rotate(node *&t, int d) //传引用很重要!! { node *k = t->ch[d ^ 1]; t->ch[d ^ 1] = k->ch[d]; k->ch[d] = t; t->pullup(), k->pullup(); // 注意此时k已经是t的父亲 t = k; } void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!! { int d = t->cmp(v); //下一步该走的方向 if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点 { int d2 = t->ch[d]->cmp(v); //下两步该走的方向 if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点 { splay(v, t->ch[d]->ch[d2]); //先递归 if (d == d2) rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig else rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag } else rotate(t, d ^ 1); // zig } // else t已经是终点 } void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!! { int d = t->cmpkth(k); if (d == 1) k -= t->ch[0]->size + t->cnt; if (d != -1) { int d2 = t->ch[d]->cmpkth(k); int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k; if (d2 != -1) { splaykth(k2, t->ch[d]->ch[d2]); if (d == d2) rotate(t, d2 ^ 1), rotate(t, d ^ 1); else rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); } else rotate(t, d ^ 1); } } ``` 虽然写起来是递归,但本质还是自底向上的。百度可以找到真正的自顶向下伸展的方法。 ----- 既然Splay可以变来变去,那么很多操作都有**“Splay特色”**的写法: ###求前驱/后继 逛了一圈发现都是暴力插入再求前驱/后继再删除的,下面介绍一个优雅的方法。或许是我原创的吧。 首先伸展X至根,如果X存在,根就会变成X,X的前驱就是左子树最大的值;再伸展左子树的最大值成为左子树的根,就是根节点的左儿子。求后继同理。 如果X不存在呢?可以证明**查找节点时最后一个访问的节点必定是前驱或者后继**。所以伸展后根就是X的前驱和后继之一。 当根是前驱的时候,前驱就是根,后继就是右子树的最小值; 当根是后继的时候,前驱就是左子树的最大值,后继就是根。 ```cpp int lower(int v, node *&t = root) // 前驱 { splay(v, t); if (t->val >= v) // 根是X或是X的后驱 { if (t->ch[0] == nil) return -inf; splay(inf, t->ch[0]); // 相当于伸展左子树的最大值 return t->ch[0]->val; } else return t->val; } int upper(int v, node *&t = root) // 后驱 { splay(v, t); if (t->val <= v) // 根是X或是X的前驱 { if (t->ch[1] == nil) return inf; splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值 return t->ch[1]->val; } else return t->val; } ``` **不严谨的证明:** 用反证法证明。 想象一下对这棵树中序遍历得到一个有序序列。查找操作和二分查找是一样的,不过每次是以子树的根为分界点,进入左边或右边的序列(不包含该分界点)继续寻找。 由于该序列始终是连续的,若X不存在,最后序列必定会变成空的。考虑在此之前的上一步,若是**在一个小于前驱的结点**,则下一步必定是往包含前驱的方向缩小序列,故这一步不可能是最后一步(除非前驱不存在)。**在一个大于后继的结点**同理。 故经过的最后一个结点必定是前驱或后继(若存在)。 ### 求排名 伸展X成为根,求左子树的元素数量+1即可 ```cpp int getrank(int v, node *&t = root) { splay(v, t); return t->ch[0]->size + 1; } ``` ### 求K大 ```cpp int getkth(int k, node *&t = root) { splaykth(k, t); return t->val; } ``` `splaykth(k, root)`然后输出`root->val`即可。 ### 分裂 将树t分为小于等于X和大于X两部分: - 若X在树上,先伸展X,这时候树的左子树都是小于X的元素,右子树都是大于X的元素。断开根和右子树的连接即可。 - 若X不在树上,伸展操作将会把X的前驱或后继伸展至根(证明在下面)。只需判断下根是大于X还是小于X,决定断开根和左子树的连接还是右子树的连接。 ```cpp node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素 { if (t == nil) return nil; splay(v, t); node *t1, *t2; // 用于保存分裂后的两棵树 if (t->val <= v) t1 = t, t2 = t->ch[1], t->ch[1] = nil; else t1 = t->ch[0], t2 = t, t->ch[0] = nil; t->pullup(); t = t1; return t2; } ``` ### 合并 要合并的两棵树分别为T1和T2,则必须保证树T1的最大值**严格小于**树T2的最小值。 先伸展T1的最大值节点,这时候T1的根必然没有右子树,将T2接上去即可。 ```cpp void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树 { if (t1 == nil) swap(t1, t2); splay(inf, t1); t1->ch[1] = t2; t2 = nil; t1->pullup(); } ``` ### 插入 为什么要把插入和删除放到最后面。因为这两个操作可以通过分裂和合并优雅地实现。 先将树分裂为小于或等于X的树T1和大于X的树T2。 由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否等于X:若是,说明出现重复,计数加一;否则合并T1和新节点。之后重新合并新的T1和T2。 ```cpp void insert(int v, node *&t = root) { node *t2 = split(v, t); if (t->val == v) t->cnt++; else { node *nd = new node(v); merge(t, nd); } merge(t, t2); } ``` ### 删除 先将树分裂为小于或等于X的树T1和大于X的树T2。 由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否为X且计数减一后为0:若是,用T1的左子树代替T1,并删除原T1的根;否则不处理。之后重新合并新的T1和T2。 ```cpp void erase(int v, node *&t = root) { node *t2 = split(v, t); if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除 { node *t3 = t->ch[0]; delete t; t = t3; } merge(t, t2); } ``` ## 模板 ```cpp // https://www.luogu.org/problem/show?pid=3369 // UPD: 2017/12/3 #include <iostream> using namespace std; namespace splay // 数据结构用namespace装着是个人习惯 { const int inf = 0x7fffffff; struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车 struct node { node *ch[2]; // ch[0]是左儿子指针,ch[1]是右儿子指针 int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数 int cmp(int v) // 如代码所示,返回要寻找值为v的元素该向左走还是向右走 { if (v == val) return -1; else return v < val ? 0 : 1; } int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走 { if (k <= ch[0]->size) return 0; else if (k <= ch[0]->size + cnt) return -1; else return 1; } void pullup() { size = cnt + ch[0]->size + ch[1]->size; } // 用于插入或删除后重新计算size node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } // 普通的构造函数 } * root; void init() // 主要用来初始化哨兵节点 { nil = new node(0); root = nil->ch[0] = nil->ch[1] = nil; nil->size = nil->cnt = 0; } void rotate(node *&t, int d) //传引用很重要!! { node *k = t->ch[d ^ 1]; t->ch[d ^ 1] = k->ch[d]; k->ch[d] = t; t->pullup(), k->pullup(); // 注意此时k已经是t的父亲 t = k; } void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!! { int d = t->cmp(v); //下一步该走的方向 if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点 { int d2 = t->ch[d]->cmp(v); //下两步该走的方向 if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点 { splay(v, t->ch[d]->ch[d2]); //先递归 if (d == d2) rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig else rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag } else rotate(t, d ^ 1); // zig } // else t已经是终点 } void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!! { int d = t->cmpkth(k); if (d == 1) k -= t->ch[0]->size + t->cnt; if (d != -1) { int d2 = t->ch[d]->cmpkth(k); int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k; if (d2 != -1) { splaykth(k2, t->ch[d]->ch[d2]); if (d == d2) rotate(t, d2 ^ 1), rotate(t, d ^ 1); else rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); } else rotate(t, d ^ 1); } } // WARNING: split和merge必须要写得格外小心 node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素 { if (t == nil) return nil; splay(v, t); node *t1, *t2; // 用于保存分裂后的两棵树 if (t->val <= v) t1 = t, t2 = t->ch[1], t->ch[1] = nil; else t1 = t->ch[0], t2 = t, t->ch[0] = nil; t->pullup(); t = t1; return t2; } void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树 { if (t1 == nil) swap(t1, t2); splay(inf, t1); t1->ch[1] = t2; t2 = nil; t1->pullup(); } void insert(int v, node *&t = root) { node *t2 = split(v, t); if (t->val == v) t->cnt++; else { node *nd = new node(v); merge(t, nd); } merge(t, t2); } void erase(int v, node *&t = root) { node *t2 = split(v, t); if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除 { node *t3 = t->ch[0]; delete t; t = t3; } merge(t, t2); } int getrank(int v, node *&t = root) { splay(v, t); return t->ch[0]->size + 1; } int getkth(int k, node *&t = root) { splaykth(k, t); return t->val; } int lower(int v, node *&t = root) // 前驱 { splay(v, t); if (t->val >= v) // 根是X或是X的后驱 { if (t->ch[0] == nil) return -inf; splay(inf, t->ch[0]); // 相当于伸展左子树的最大值 return t->ch[0]->val; } else return t->val; } int upper(int v, node *&t = root) // 后驱 { splay(v, t); if (t->val <= v) // 根是X或是X的前驱 { if (t->ch[1] == nil) return inf; splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值 return t->ch[1]->val; } else return t->val; } } int main() { ios::sync_with_stdio(false); splay::init(); int n, opt, x; cin >> n; while (n--) { cin >> opt >> x; switch (opt) { case 1: splay::insert(x); break; case 2: splay::erase(x); break; case 3: cout << splay::getrank(x) << endl; break; case 4: cout << splay::getkth(x) << endl; break; case 5: cout << splay::lower(x) << endl; break; case 6: cout << splay::upper(x) << endl; break; } } return 0; } ```