题解 P3369 【【模板】普通平衡树(Treap/SBT)】
huangwenlong
2017-12-03 19:22:35
7个月前写过一篇题解,今天回来看下结果自己都看不下去,于是就来重写了。
# 递归版Splay
优点:**不用维护父指针!!!**刚开始学写非递归Splay的时候被父指针的维护坑了好久!!!
参考了大刘的《训练指南》。
## 实现
**前排警告:前方存在大量结构体+指针**
首先我用名为node的结构体保存节点:
```cpp
const int inf = 0x7fffffff;
struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车
struct node
{
node *ch[2]; // ch[0]是左儿子指针,ch[1]是右儿子指针
int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数
int cmp(int v) // 如代码所示,返回要寻找值为v的元素该向左走还是向右走
{
if (v == val)
return -1;
else
return v < val ? 0 : 1;
}
int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走
{
if (k <= ch[0]->size)
return 0;
else if (k <= ch[0]->size + cnt)
return -1;
else
return 1;
}
void pullup() { size = cnt + ch[0]->size + ch[1]->size; } // 用于插入或删除后重新计算size
node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } // 普通的构造函数
} * root;
void init() // 主要用来初始化哨兵节点
{
nil = new node(0);
root = nil->ch[0] = nil->ch[1] = nil;
nil->size = nil->cnt = 0;
}
```
下面的说明中node既可以表示**节点**也可以表示**树**。
用0/1表示向左/向右,用`ch[0]`/`ch[1]`表示左儿子/右儿子指针,用`cmp(int v)`返回往左走/往右走,用异或运算取相反方向,这些都是来自大刘《训练指南》的技巧。因为平衡树中对称的情形太多了,合理运用这些技巧可以压缩代码量。
### 伸展
所谓递归Splay,其实就是把寻找节点和伸展节点写在了一起,把递归寻找节点展开一层以后塞几行代码调用旋转。如果找不到这个值的节点,就会伸展最后一个访问到的节点。
```cpp
void rotate(node *&t, int d) //传引用很重要!!
{
node *k = t->ch[d ^ 1];
t->ch[d ^ 1] = k->ch[d];
k->ch[d] = t;
t->pullup(), k->pullup(); // 注意此时k已经是t的父亲
t = k;
}
void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!!
{
int d = t->cmp(v); //下一步该走的方向
if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点
{
int d2 = t->ch[d]->cmp(v); //下两步该走的方向
if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点
{
splay(v, t->ch[d]->ch[d2]); //先递归
if (d == d2)
rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig
else
rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag
}
else
rotate(t, d ^ 1); // zig
}
// else t已经是终点
}
void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!!
{
int d = t->cmpkth(k);
if (d == 1)
k -= t->ch[0]->size + t->cnt;
if (d != -1)
{
int d2 = t->ch[d]->cmpkth(k);
int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k;
if (d2 != -1)
{
splaykth(k2, t->ch[d]->ch[d2]);
if (d == d2)
rotate(t, d2 ^ 1), rotate(t, d ^ 1);
else
rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1);
}
else
rotate(t, d ^ 1);
}
}
```
虽然写起来是递归,但本质还是自底向上的。百度可以找到真正的自顶向下伸展的方法。
-----
既然Splay可以变来变去,那么很多操作都有**“Splay特色”**的写法:
###求前驱/后继
逛了一圈发现都是暴力插入再求前驱/后继再删除的,下面介绍一个优雅的方法。或许是我原创的吧。
首先伸展X至根,如果X存在,根就会变成X,X的前驱就是左子树最大的值;再伸展左子树的最大值成为左子树的根,就是根节点的左儿子。求后继同理。
如果X不存在呢?可以证明**查找节点时最后一个访问的节点必定是前驱或者后继**。所以伸展后根就是X的前驱和后继之一。
当根是前驱的时候,前驱就是根,后继就是右子树的最小值;
当根是后继的时候,前驱就是左子树的最大值,后继就是根。
```cpp
int lower(int v, node *&t = root) // 前驱
{
splay(v, t);
if (t->val >= v) // 根是X或是X的后驱
{
if (t->ch[0] == nil)
return -inf;
splay(inf, t->ch[0]); // 相当于伸展左子树的最大值
return t->ch[0]->val;
}
else
return t->val;
}
int upper(int v, node *&t = root) // 后驱
{
splay(v, t);
if (t->val <= v) // 根是X或是X的前驱
{
if (t->ch[1] == nil)
return inf;
splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值
return t->ch[1]->val;
}
else
return t->val;
}
```
**不严谨的证明:**
用反证法证明。
想象一下对这棵树中序遍历得到一个有序序列。查找操作和二分查找是一样的,不过每次是以子树的根为分界点,进入左边或右边的序列(不包含该分界点)继续寻找。
由于该序列始终是连续的,若X不存在,最后序列必定会变成空的。考虑在此之前的上一步,若是**在一个小于前驱的结点**,则下一步必定是往包含前驱的方向缩小序列,故这一步不可能是最后一步(除非前驱不存在)。**在一个大于后继的结点**同理。
故经过的最后一个结点必定是前驱或后继(若存在)。
### 求排名
伸展X成为根,求左子树的元素数量+1即可
```cpp
int getrank(int v, node *&t = root)
{
splay(v, t);
return t->ch[0]->size + 1;
}
```
### 求K大
```cpp
int getkth(int k, node *&t = root)
{
splaykth(k, t);
return t->val;
}
```
`splaykth(k, root)`然后输出`root->val`即可。
### 分裂
将树t分为小于等于X和大于X两部分:
- 若X在树上,先伸展X,这时候树的左子树都是小于X的元素,右子树都是大于X的元素。断开根和右子树的连接即可。
- 若X不在树上,伸展操作将会把X的前驱或后继伸展至根(证明在下面)。只需判断下根是大于X还是小于X,决定断开根和左子树的连接还是右子树的连接。
```cpp
node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素
{
if (t == nil)
return nil;
splay(v, t);
node *t1, *t2; // 用于保存分裂后的两棵树
if (t->val <= v)
t1 = t, t2 = t->ch[1], t->ch[1] = nil;
else
t1 = t->ch[0], t2 = t, t->ch[0] = nil;
t->pullup();
t = t1;
return t2;
}
```
### 合并
要合并的两棵树分别为T1和T2,则必须保证树T1的最大值**严格小于**树T2的最小值。
先伸展T1的最大值节点,这时候T1的根必然没有右子树,将T2接上去即可。
```cpp
void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树
{
if (t1 == nil)
swap(t1, t2);
splay(inf, t1);
t1->ch[1] = t2;
t2 = nil;
t1->pullup();
}
```
### 插入
为什么要把插入和删除放到最后面。因为这两个操作可以通过分裂和合并优雅地实现。
先将树分裂为小于或等于X的树T1和大于X的树T2。
由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否等于X:若是,说明出现重复,计数加一;否则合并T1和新节点。之后重新合并新的T1和T2。
```cpp
void insert(int v, node *&t = root)
{
node *t2 = split(v, t);
if (t->val == v)
t->cnt++;
else
{
node *nd = new node(v);
merge(t, nd);
}
merge(t, t2);
}
```
### 删除
先将树分裂为小于或等于X的树T1和大于X的树T2。
由于T1的根没有右子树,故T1的根就是T1的最大值。检查T1的根是否为X且计数减一后为0:若是,用T1的左子树代替T1,并删除原T1的根;否则不处理。之后重新合并新的T1和T2。
```cpp
void erase(int v, node *&t = root)
{
node *t2 = split(v, t);
if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除
{
node *t3 = t->ch[0];
delete t;
t = t3;
}
merge(t, t2);
}
```
## 模板
```cpp
// https://www.luogu.org/problem/show?pid=3369
// UPD: 2017/12/3
#include <iostream>
using namespace std;
namespace splay // 数据结构用namespace装着是个人习惯
{
const int inf = 0x7fffffff;
struct node *nil; // 哨兵节点,用于防止访问无效内存导致翻车
struct node
{
node *ch[2]; // ch[0]是左儿子指针,ch[1]是右儿子指针
int val, cnt, size; // 元素的值、元素个数(处理重复)、该节点构成的子树包含元素的个数
int cmp(int v) // 如代码所示,返回要寻找值为v的元素该向左走还是向右走
{
if (v == val)
return -1;
else
return v < val ? 0 : 1;
}
int cmpkth(int k) // 同上,返回要寻找第k小元素该向左走还是向右走
{
if (k <= ch[0]->size)
return 0;
else if (k <= ch[0]->size + cnt)
return -1;
else
return 1;
}
void pullup() { size = cnt + ch[0]->size + ch[1]->size; } // 用于插入或删除后重新计算size
node(int v) : val(v), cnt(1), size(1) { ch[0] = ch[1] = nil; } // 普通的构造函数
} * root;
void init() // 主要用来初始化哨兵节点
{
nil = new node(0);
root = nil->ch[0] = nil->ch[1] = nil;
nil->size = nil->cnt = 0;
}
void rotate(node *&t, int d) //传引用很重要!!
{
node *k = t->ch[d ^ 1];
t->ch[d ^ 1] = k->ch[d];
k->ch[d] = t;
t->pullup(), k->pullup(); // 注意此时k已经是t的父亲
t = k;
}
void splay(int v, node *&t) // 在树t中寻找值为v的节点,并伸展成为t的根节点;传引用很重要!!
{
int d = t->cmp(v); //下一步该走的方向
if (d != -1 && t->ch[d] != nil) //如果下一步可以走向一个合法结点
{
int d2 = t->ch[d]->cmp(v); //下两步该走的方向
if (d2 != -1 && t->ch[d]->ch[d2] != nil) //如果下两步可以走向一个合法结点
{
splay(v, t->ch[d]->ch[d2]); //先递归
if (d == d2)
rotate(t, d2 ^ 1), rotate(t, d ^ 1); // zig-zig
else
rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1); //zig-zag
}
else
rotate(t, d ^ 1); // zig
}
// else t已经是终点
}
void splaykth(int k, node *&t) // 同上,在树t中寻找第k小的节点,并伸展成为t的根节点;传引用很重要!!
{
int d = t->cmpkth(k);
if (d == 1)
k -= t->ch[0]->size + t->cnt;
if (d != -1)
{
int d2 = t->ch[d]->cmpkth(k);
int k2 = (d2 == 1) ? k - (t->ch[d]->ch[0]->size + t->ch[d]->cnt) : k;
if (d2 != -1)
{
splaykth(k2, t->ch[d]->ch[d2]);
if (d == d2)
rotate(t, d2 ^ 1), rotate(t, d ^ 1);
else
rotate(t->ch[d], d2 ^ 1), rotate(t, d ^ 1);
}
else
rotate(t, d ^ 1);
}
}
// WARNING: split和merge必须要写得格外小心
node *split(int v, node *&t) // 分裂后,树t都是小于等于X的元素,返回的树都是大于X的元素
{
if (t == nil)
return nil;
splay(v, t);
node *t1, *t2; // 用于保存分裂后的两棵树
if (t->val <= v)
t1 = t, t2 = t->ch[1], t->ch[1] = nil;
else
t1 = t->ch[0], t2 = t, t->ch[0] = nil;
t->pullup();
t = t1;
return t2;
}
void merge(node *&t1, node *&t2) // 合并后得到的树是t1,t2会变为空树
{
if (t1 == nil)
swap(t1, t2);
splay(inf, t1);
t1->ch[1] = t2;
t2 = nil;
t1->pullup();
}
void insert(int v, node *&t = root)
{
node *t2 = split(v, t);
if (t->val == v)
t->cnt++;
else
{
node *nd = new node(v);
merge(t, nd);
}
merge(t, t2);
}
void erase(int v, node *&t = root)
{
node *t2 = split(v, t);
if (t->val == v && --(t->cnt) < 1) // 命中节点,计数先减一,再判断是否要将节点删除
{
node *t3 = t->ch[0];
delete t;
t = t3;
}
merge(t, t2);
}
int getrank(int v, node *&t = root)
{
splay(v, t);
return t->ch[0]->size + 1;
}
int getkth(int k, node *&t = root)
{
splaykth(k, t);
return t->val;
}
int lower(int v, node *&t = root) // 前驱
{
splay(v, t);
if (t->val >= v) // 根是X或是X的后驱
{
if (t->ch[0] == nil)
return -inf;
splay(inf, t->ch[0]); // 相当于伸展左子树的最大值
return t->ch[0]->val;
}
else
return t->val;
}
int upper(int v, node *&t = root) // 后驱
{
splay(v, t);
if (t->val <= v) // 根是X或是X的前驱
{
if (t->ch[1] == nil)
return inf;
splay(-inf, t->ch[1]); // 相当于伸展右子树的最小值
return t->ch[1]->val;
}
else
return t->val;
}
}
int main()
{
ios::sync_with_stdio(false);
splay::init();
int n, opt, x;
cin >> n;
while (n--)
{
cin >> opt >> x;
switch (opt)
{
case 1:
splay::insert(x);
break;
case 2:
splay::erase(x);
break;
case 3:
cout << splay::getrank(x) << endl;
break;
case 4:
cout << splay::getkth(x) << endl;
break;
case 5:
cout << splay::lower(x) << endl;
break;
case 6:
cout << splay::upper(x) << endl;
break;
}
}
return 0;
}
```