Splay学习笔记
智子
2019-03-29 13:38:56
# 前言
伸展树(英语:Splay Tree)是一种能够**自我平衡**的二叉查找树,它能在均摊O(log n)的时间内完成基于伸展(Splay)操作的插入、查找、修改和删除操作。
# 定义
## 节点
`val`:节点node的值
`par`:节点node的父节点
`ch[2]`:节点node的左子节点与右子节点
`siz`:以节点node为根节点的子树的节点总数
`cnt`:数值与节点node相同的节点的数量(都储存在节点node中)
代码:
```cpp
int ch[MAXN][2], par[MAXN], val[MAXN], siz[MAXN], cnt[MAXN];
```
## 树
root:根节点
cnt:总结点数
```cpp
int tot, rt;
```
# 操作
## 基本操作
### pushup
pushup()函数:**更新**节点p的size值
```cpp
void pushup(int p) {
siz[p] = siz[ch[p][0]] + siz[ch[p][1]] + cnt[p];
}
```
---
### check
check()函数:询问节点p是其父节点的左子节点还是右子节点
```cpp
int chk(int p) {
return ch[par[p]][0] == p ? 0 : 1; //0代表左子节点,1代表右子节点
}
```
---
## 旋转
### rotate
旋转是平衡树最主要的操作,其本质在于,每次进行旋转时,左右子树当中之一高度 -1,另外一棵高度 +1,以达到平衡的目的。
左旋:
第一次连边,节点x的子节点成为x的父节点的右子节点
第二次连边,节点x成为节点x的父节点的父节点的子节点,方向与x的父节点相同
第三次连边,节点x的父节点成为节点x的左子节点
![](https://keepthethink.github.io/images/left_rotate.jpg)
右旋:
第一次连边,节点x的子节点成为x的父节点的左子节点
第二次连边,节点x成为节点x的父节点的父节点的子节点,方向与x的父节点相同
第三次连边,节点x的父节点成为节点x的右子节点
![](https://keepthethink.github.io/images/right_rotate.jpg)
旋转操作只与标为红,蓝,绿的三个部分有关。
```cpp
void rotate(int x) {
int y = par[x], z = par[y], d = chk(x), w = ch[x][d ^ 1]; //w判断应该左旋还是右旋
ch[y][d] = w; par[w] = y; //第一次连边,节点x的子节点连接到x的父节点,方向与节点x相同
ch[z][chk(y)] = x; par[x] = z; //第二次连边,节点x连接到节点x的父节点的父节点,方向与x的父节点相同
ch[x][d ^ 1] = y; par[y] = x; //第三次连边,节点x的父节点连接到节点x,方向与节点x原先的方向相反
pushup(y); //更新子树
pushup(x); //更新子树
}
```
## 伸展
### splay
Splay操作:将节点x旋转到节点dist的子节点。通常是将该节点旋转到根节点,在这种情况下,应当将root置为x
最朴素的想法:只要父节点不是dist就一直旋转该节点,但这样很容易被某些机(wu)智(liang)出题人卡。
```cpp
void splay(int x, int goal = 0) {
while(par[x] != goal) {
rotate(x);
}
if(goal == 0) {
rt = x;
}
}
```
所以,在实际操作中,通常会预判节点x的父节点的方向,若方向一致则旋转其父节点,减少被卡的可能性。~~多么妖娆~~
```cpp
void splay(int x, int goal = 0) {
while(par[x] != goal) {
int y = par[x], z = par[y];
if(z != goal) {
if(chk(x) == chk(y)) {
rotate(y); //方向一致则旋转x的父节点
} else {
rotate(x); //方向不一致则旋转x
}
}
rotate(x);
}
if(goal == 0) {
rt = x;
}
}
```
## 查找
### find
查找值为x的节点,找到后将其置为root以便操作。
find操作的意义在于将值为x的节点伸展(splay)到根,在不存在值为x的节点的情况下,应将小于x的节点中最大的节点伸展(splay)到根。
```cpp
void find(int v) {
int p = rt;
while(ch[p][v > val[p]] && val[p] != v) {
p = ch[p][v > val[p]];
}
splay(p);
}
```
## 公共操作
如果将本文讲的Splay打包成一个class,则前文所述的操作应包含在private中,本节所述的操作应包含在public中。
### insert
Splay中的insert其实与朴素BST中的insert没有什么区别,但若直接插入可能导致树退化为链,所以要在末尾处调用一次splay()函数,使Splay树保持平衡。
```cpp
void insert(int v) {
int cur = rt, p = 0;
while(cur && val[cur] != v) {
p = cur;
cur = ch[cur][v > val[cur]];
}
if(cur != 0) {
cnt[cur]++;
} else {
cur = ++tot;
if(p != 0) {
ch[p][v > val[p]] = cur;
}
ch[cur][0] = ch[cur][1] = 0;
par[cur] = p;
siz[cur] = cnt[cur] = 1;
val[cur] = v;
}
splay(cur);
}
```
### serial
serial操作:查询值为x的节点,在find操作的基础上,serial只需要在find过后输出左子树节点数量即可。
```cpp
find(x);
printf("%d\n", siz[ch[rt][0]]);
```
### pre
找出值为x的节点的前驱,将节点splay到root后在左子树查找最大值即可。
```cpp
int pre(int v) {
find(v);
if(val[rt] < v) {
return rt;
}
int p = ch[rt][0];
while(ch[p][1]) {
p = ch[p][1];
}
return p;
}
```
### suc
找出值为x的点的后继,与前驱同理。
```cpp
int nxt(int v) {
find(v);
if(val[rt] > v) {
return rt;
}
int p = ch[rt][1];
while(ch[p][0]) {
p = ch[p][0];
}
return p;
}
```
### remove
删除一个节点。
删除较为复杂,分四步来完成:
1. 定义last为节点的前驱,next为节点的后继。
2. 将last节点splay到root,这时last的左子树皆小于x
3. 将next节点splay到last的子节右点,此时next的右子树皆大于x
4. next的左节点rm必然满足 last < rm < next,删除rm即可
```cpp
void remove(int v) {
int last = pre(v), next = nxt(v);
splay(last);
splay(next, last);
int del = ch[next][0];
if(cnt[del] > 1) {
cnt[del]--;
splay(del);
} else {
ch[next][0] = 0;
}
pushup(next);
pushup(rt);
}
```
### rank
查找排名为k的节点
用一个指针cur从root开始查找,每次根据左子树大小于k的关系修改cur以及k。
```cpp
int kth(int k) {
int p = rt;
while(1) {
if(ch[p][0] && k <= siz[ch[p][0]]) {
p = ch[p][0];
} else if(k > siz[ch[p][0]] + cnt[p]) {
k -= siz[ch[p][0]] + cnt[p];
p = ch[p][1];
} else {
return p;
}
}
}
```
# 完整代码
```cpp
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 200000 + 5;
int ch[MAXN][2], par[MAXN], val[MAXN], siz[MAXN], cnt[MAXN];
int tot, rt;
int chk(int p) {
return ch[par[p]][0] == p ? 0 : 1;
}
void pushup(int p) {
siz[p] = siz[ch[p][0]] + siz[ch[p][1]] + cnt[p];
}
void rotate(int x) {
int y = par[x], z = par[y], d = chk(x), w = ch[x][d ^ 1];
ch[y][d] = w; par[w] = y;
ch[z][chk(y)] = x; par[x] = z;
ch[x][d ^ 1] = y; par[y] = x;
pushup(y);
pushup(x);
}
void splay(int x, int goal = 0) {
while(par[x] != goal) {
int y = par[x], z = par[y];
if(z != goal) {
if(chk(x) == chk(y)) {
rotate(y);
} else {
rotate(x);
}
}
rotate(x);
}
if(goal == 0) {
rt = x;
}
}
void insert(int v) {
int cur = rt, p = 0;
while(cur && val[cur] != v) {
p = cur;
cur = ch[cur][v > val[cur]];
}
if(cur != 0) {
cnt[cur]++;
} else {
cur = ++tot;
if(p != 0) {
ch[p][v > val[p]] = cur;
}
ch[cur][0] = ch[cur][1] = 0;
par[cur] = p;
siz[cur] = cnt[cur] = 1;
val[cur] = v;
}
splay(cur);
}
void find(int v) {
int p = rt;
while(ch[p][v > val[p]] && val[p] != v) {
p = ch[p][v > val[p]];
}
splay(p);
}
int kth(int k) {
int p = rt;
while(1) {
if(ch[p][0] && k <= siz[ch[p][0]]) {
p = ch[p][0];
} else if(k > siz[ch[p][0]] + cnt[p]) {
k -= siz[ch[p][0]] + cnt[p];
p = ch[p][1];
} else {
return p;
}
}
}
int pre(int v) {
find(v);
if(val[rt] < v) {
return rt;
}
int p = ch[rt][0];
while(ch[p][1]) {
p = ch[p][1];
}
return p;
}
int nxt(int v) {
find(v);
if(val[rt] > v) {
return rt;
}
int p = ch[rt][1];
while(ch[p][0]) {
p = ch[p][0];
}
return p;
}
void remove(int v) {
int last = pre(v), next = nxt(v);
splay(last);
splay(next, last);
int del = ch[next][0];
if(cnt[del] > 1) {
cnt[del]--;
splay(del);
} else {
ch[next][0] = 0;
}
pushup(next);
pushup(rt);
}
int main() {
int n, op, x;
scanf("%d", &n);
insert(-1e9);
insert(1e9);
for(int i = 1; i <= n; i++) {
scanf("%d%d", &op, &x);
if(op == 1) {
insert(x);
} else if(op == 2) {
remove(x);
} else if(op == 3) {
find(x);
printf("%d\n", siz[ch[rt][0]]);
} else if(op == 4) {
printf("%d\n", val[kth(x + 1)]);
} else if(op == 5) {
printf("%d\n", val[pre(x)]);
} else if(op == 6) {
printf("%d\n", val[nxt(x)]);
}
}
return 0;
}
```
# 参考资料
[伸展树- 维基百科,自由的百科全书](https://zh.wikipedia.org/zh-hans/伸展树)
[Splay Tree Introduction](https://www.youtube.com/watch?v=IBY4NtxmGg8)