【模板】可持久化线段树(主席树)
0AND1STORY
2019-11-11 13:06:38
我用 class 封装了一个主席树（可持久化线段树）的模板。注意：按模板题要求，每次操作（包括查询）都会生成一个新的历史版本。代码如下：
模板题：[P3919 【模板】可持久化数组(可持久化线段树/平衡树)](https://www.luogu.com.cn/problem/P3919)
```cpp
#include <climits>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <stdexcept>
using namespace std;
/**
 * Persistent segment tree ("president tree", 主席树) over positions [l, r]
 * supporting point assignment and point lookup on any stored version.
 *
 * @tparam T      value type stored in the leaves (must be default-constructible
 *                and copy-assignable).
 * @tparam _size  capacity of both the node pool and the version table. Slot 0
 *                of the node pool is an empty sentinel, so at most _size - 1
 *                nodes can be allocated and at most _size - 1 versions can
 *                follow version 0.
 *
 * Versions are numbered from 0 (created by build()). Every call to update()
 * AND to query() appends one new version (each operation produces a version);
 * query() simply records the root it read.
 *
 * Errors: exhausting the node pool or version table throws std::length_error
 * instead of writing past the arrays; an unknown version or a position outside
 * [l, r] throws std::out_of_range; an empty range (l > r) throws
 * std::invalid_argument.
 */
template<typename T, size_t _size>
class PresidentTree {
private:
    // Child links are stored as int to keep each node compact (the pool is
    // large); the static_assert below guarantees every node index fits.
    struct Node {
        int l, r;
        T v;
    };
    static_assert(_size >= 1, "PresidentTree: _size must be at least 1");
    static_assert(_size <= static_cast<size_t>(INT_MAX),
                  "PresidentTree: node indices are stored as int, so _size must not exceed INT_MAX");

    size_t sz;          // index of the most recently allocated node (0 = sentinel only)
    size_t history_sz;  // index of the newest version recorded in root[]
    // Allocated with new[] (default-initialization): for trivially
    // constructible T the slots are not written up front. unique_ptr frees them.
    std::unique_ptr<Node[]> tree;
    std::unique_ptr<size_t[]> root;

public:
    PresidentTree()
        : sz(0), history_sz(0), tree(new Node[_size]), root(new size_t[_size]) {
        tree[0] = Node();  // sentinel: children 0, value T()
        root[0] = 0;       // version 0 points at the sentinel until build()
    }

private:
    // Reserves the next node slot, refusing to run past the end of the pool.
    size_t _new_node() {
        if (sz + 1 >= _size) throw std::length_error("PresidentTree: node pool exhausted");
        return ++sz;
    }

    // Reserves the next version slot in root[].
    size_t _new_version() {
        if (history_sz + 1 >= _size) throw std::length_error("PresidentTree: version table exhausted");
        return ++history_sz;
    }

    // Validates the arguments shared by update() and query().
    void _check_access(size_t rt, size_t l, size_t r, size_t pos) const {
        if (rt > history_sz) throw std::out_of_range("PresidentTree: unknown version");
        if (l > r) throw std::invalid_argument("PresidentTree: empty range (l > r)");
        if (pos < l || pos > r) throw std::out_of_range("PresidentTree: position outside [l, r]");
    }

    // Copies an existing node into a freshly allocated slot.
    size_t _clone(size_t node) {
        const size_t copy = _new_node();
        tree[copy] = tree[node];
        return copy;
    }

    // Builds the subtree for a[l..r] (l <= r) and returns its root index.
    template<typename ArrayType>
    size_t _build(const ArrayType& a, size_t l, size_t r) {
        const size_t node = _new_node();
        if (l == r) {
            tree[node].l = tree[node].r = 0;
            tree[node].v = a[l];
        } else {
            const size_t mid = l + (r - l) / 2;  // same as (l + r) / 2 for l <= r, without overflow
            tree[node].v = T();                  // internal value unused; keep it initialized for cloning
            tree[node].l = static_cast<int>(_build(a, l, mid));
            tree[node].r = static_cast<int>(_build(a, mid + 1, r));
        }
        return node;
    }

    // Returns the root of a new subtree equal to `node` except that position k
    // holds v; only the nodes on the root-to-leaf path of k are copied.
    size_t _update(size_t node, size_t l, size_t r, size_t k, const T& v) {
        const size_t copy = _clone(node);
        if (l == r) {
            tree[copy].v = v;
        } else {
            const size_t mid = l + (r - l) / 2;
            if (k <= mid) {
                tree[copy].l = static_cast<int>(_update(static_cast<size_t>(tree[copy].l), l, mid, k, v));
            } else {
                tree[copy].r = static_cast<int>(_update(static_cast<size_t>(tree[copy].r), mid + 1, r, k, v));
            }
        }
        return copy;
    }

    // Walks iteratively from `node` down to the leaf covering position k.
    T _query(size_t node, size_t l, size_t r, size_t k) const {
        while (l < r) {
            const size_t mid = l + (r - l) / 2;
            if (k <= mid) {
                node = static_cast<size_t>(tree[node].l);
                r = mid;
            } else {
                node = static_cast<size_t>(tree[node].r);
                l = mid + 1;
            }
        }
        return tree[node].v;
    }

public:
    /**
     * Builds version 0 from a[l..r]; `a` only needs operator[].
     * Calling build() again overwrites version 0 but keeps later versions and
     * already-allocated nodes.
     * @return the root node index of version 0.
     * @throws std::invalid_argument if l > r; std::length_error if the pool is too small.
     */
    template<typename ArrayType>
    size_t build(const ArrayType& a, size_t l, size_t r) {
        if (l > r) throw std::invalid_argument("PresidentTree: empty range (l > r)");
        return root[0] = _build(a, l, r);
    }

    /**
     * Appends a new version equal to version rt with position pos set to val.
     * @return the root node index of the new version.
     */
    size_t update(size_t rt, size_t l, size_t r, size_t pos, const T& val) {
        _check_access(rt, l, r, pos);
        // Build the new path before claiming a version slot so a failed
        // allocation never leaves an uninitialized entry in root[].
        const size_t new_root = _update(root[rt], l, r, pos, val);
        root[_new_version()] = new_root;
        return new_root;
    }

    /**
     * Returns the value at pos in version rt and appends a copy of version rt
     * as a new version (every operation creates a version).
     */
    T query(size_t rt, size_t l, size_t r, size_t pos) {
        _check_access(rt, l, r, pos);
        const size_t node = root[rt];
        root[_new_version()] = node;
        return _query(node, l, r, pos);
    }
};
// Input limits from the P3919 statement: 1 <= n, m <= 10^6 (plus a little slack).
const size_t max_n = 1000000 + 5;
const size_t max_m = 1000000 + 5;
// Capacity of the tree's node pool / version table. build() allocates 2n - 1
// nodes and every update copies one root-to-leaf path of at most
// ceil(log2 n) + 1 = 21 nodes (n <= 2^20); node slot 0 is never handed out.
// A flat 2e7 is not enough when most of the 10^6 operations are updates.
const size_t maxn = 2 * max_n + 21 * max_m;
int n, m;
int a[max_n];  // only the initial array a[1..n], so it needs max_n, not maxn
PresidentTree<int, maxn> t;
/**
 * Reads the problem input: "n m", the initial values a[1..n], then m
 * operations "rt mode pos [val]". mode 1 derives a new version from version rt
 * with a[pos] = val; any other mode prints a[pos] as stored in version rt.
 * Both kinds of operation create a version, so before operation i exactly
 * versions 0..i-1 exist.
 * Returns 1 on malformed or out-of-range input instead of indexing past arrays.
 */
int main() {
    const int capacity = static_cast<int>(sizeof(a) / sizeof(a[0]));
    if (scanf("%d%d", &n, &m) != 2) return 1;
    if (n < 1 || n >= capacity || m < 0) return 1;
    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &a[i]) != 1) return 1;
    }
    t.build(a, 1, n);
    for (int i = 1; i <= m; ++i) {
        int rt = 0, mode = 0, pos = 0;
        if (scanf("%d%d%d", &rt, &mode, &pos) != 3) return 1;
        if (rt < 0 || rt >= i || pos < 1 || pos > n) return 1;
        if (mode == 1) {
            int val = 0;
            if (scanf("%d", &val) != 1) return 1;
            t.update(rt, 1, n, pos, val);
        } else {
            printf("%d\n", t.query(rt, 1, n, pos));
        }
    }
    return 0;
}
```