P3369普通平衡树题解

· · 题解

洛谷P3369

刚学了基础的平衡树,写篇题解巩固一下 题目传送门

解法说明

首先读一下题。嗯……明显的一道平衡树裸题。
因此我们可以用无旋 Treap 来做。
它相当于二叉搜索树加上堆。
我们要知道,无旋 Treap 的精髓是分裂和合并两个函数。
所有操作都是基于这两个操作之上。
先来看一下分裂与合并。

void split(int u, int & l, int & r, int val) { // 分裂
    if (!u) { // 分到底了
        l = r = 0;
        return;
    } 
    if (node[u].val <= val) { // 注意这里左子树是包括当前顶点的
        l = u;
        split(node[u].r, node[u].r, r, val);
    } else {
        r = u;
        split(node[u].l, l, node[u].l, val);
    }
    pushup(u); // 上传
    return;
}
int merge(int l, int r) { // 合并
    if (!l || !r) {
        return l + r;
    }
    if (node[l].pty <= node[r].pty) { // 按优先级决定
        node[l].r = merge(node[l].r, r);
        pushup(l); // 上传
        return l; // 返回根节点
    } else {
        node[r].l = merge(l, node[r].l);
        pushup(r);
        return r;
    }
}

忘了,这题要用到集合元素个数,
所以 pushup 肯定不能少啦~

void pushup(int x) {
    node[x].sz = node[node[x].l].sz + node[node[x].r].sz + 1;
    return;
}

看一下操作一:

M 中插入一个数 x

可以把 M 按照 x 的值分裂成 \le x>x 两个集合,把 x 插入到根。


void add(int val) { // 创建一个新的节点
num++;
node[num].val = val;
node[num].sz = 1; // 所在集合元素总数+1
node[num].pty = rand(); // 随机优先值,防止堆退化成链
node[num].l = 0; // 左右儿子赋初值
node[num].r = 0;
return;
}

void ins(int val) { split(root, dl, dr, val); // 按val分裂 add(val); root = merge(merge(dl, num), dr); // 合并 }

删除也一样:
```cpp
void del(int val) {
    split(root, dl, dr, val); // 分裂
    split(dl, dl, temp, val - 1);
    temp = merge(node[temp].l, node[temp].r); // 合并
    root = merge(merge(dl, temp), dr); // 因为根改变了,所以重新赋值
}

再看操作三:

查询 M 中有多少个数比 x 小,并且将得到的答案加一。

其实就是求排名。
访问左右儿子,因为 Treap 满足:

l<mid<r

所以左子树的元素个数 +1 即为答案。

void rnk(int val) {
split(root, dl, dr, val - 1); // 按x-1进行分裂
cout << node[dl].sz + 1 << endl; // 包括自己
root = merge(dl, dr);
}

操作四按左右子树递归寻找即可。

int kth(int u, int p) { // 注意返回的是结点编号
if (p <= node[node[u].l].sz) { // 在左子树中
return kth(node[u].l, p);
} else if (p == node[node[u].l].sz + 1) { // 自己是根节点
return u;
} else { // 访问右子树
p -= node[node[u].l].sz + 1;
return kth(node[u].r, p);
}
}

操作五六本质上是一个东西。话不多说,上代码:


void pre(int val) { // 找前驱
split(root, dl, dr, val - 1); // 由于前驱一定比当前根结点小,所以访问左子树
cout << node[kth(dl, node[dl].sz)].val << endl; // 前驱是左子树中最大的
root = merge(dl, dr); // 重新赋值
}

void nxt(int val) { // 找后继 split(root, dl, dr, val); // 由于后继一定比当前根结点大,所以访问右子树 cout << node[kth(dr, 1)].val << endl; // 后继是右子树中最小的 root = merge(dl, dr); // 重新赋值 }

完整代码:
```cpp
#include <bits/stdc++.h>

using namespace std;

const int N = 1000000 + 5;

struct tree {
    int l, r, pty, val, sz;
} node[N]; 

int n, num, root, dl, dr, temp;

void pushup(int x) {
    node[x].sz = node[node[x].l].sz + node[node[x].r].sz + 1;
    return;
}

void split(int u, int & l, int & r, int val) {
    if (!u) {
        l = r = 0;
        return;
    } 
    if (node[u].val <= val) {
        l = u;
        split(node[u].r, node[u].r, r, val);
    } else {
        r = u;
        split(node[u].l, l, node[u].l, val);
    }
    pushup(u);
    return;
}

void add(int val) {
    num++;
    node[num].val = val;
    node[num].sz = 1;
    node[num].pty = rand();
    node[num].l = 0;
    node[num].r = 0;
    return;
}

int merge(int l, int r) {
    if (!l || !r) {
        return l + r;
    }
    if (node[l].pty <= node[r].pty) {
        node[l].r = merge(node[l].r, r);
        pushup(l);
        return l;
    } else {
        node[r].l = merge(l, node[r].l);
        pushup(r);
        return r;
    }
}

int kth(int u, int p) {
    if (p <= node[node[u].l].sz) {
        return kth(node[u].l, p);
    } else if (p == node[node[u].l].sz + 1) {
        return u;
    } else {
        p -= node[node[u].l].sz + 1;
        return kth(node[u].r, p);
    }
}

int main(void) {
    cin >> n;
    srand(time(NULL));
    int opt, x;
    for (int i = 1; i <= n; i++) {
        cin >> opt >> x;
        if (opt == 1){
            split(root, dl, dr, x);
            add(x);
            root = merge(merge(dl, num), dr);
        } else if (opt == 2){
            split(root, dl, dr, x);
            split(dl, dl, temp, x - 1);
            temp = merge(node[temp].l, node[temp].r);
            root = merge(merge(dl, temp), dr);
        } else if (opt == 3){
            split(root, dl, dr, x - 1);
            cout << node[dl].sz + 1 << endl;
            root = merge(dl, dr);
        } else if (opt == 4){
            cout << node[kth(root, x)].val << endl;
        } else if (opt == 5){
            split(root, dl, dr, x - 1);
            cout << node[kth(dl, node[dl].sz)].val << endl;
            root = merge(dl, dr);
        } else {
            split(root, dl, dr, x);
            cout << node[kth(dr, 1)].val << endl;
            root = merge(dl, dr);
        }
    } 
    return 0;
}

完结撒花~~~
对于代码讲解有错误等,欢迎在评论区提出~