权值线段树(动态开点)
实现一个容器,支持以下操作:
- 插入
x 数 - 删除
x 数(若有多个相同的数,因只删除一个) - 查询
x 数的排名(排名定义为比当前数小的数的个数+1 ) - 查询排名为
x 的数 - 求
x 的前驱(前驱定义为小于x ,且最大的数) - 求
x 的后继(后继定义为大于x ,且最小的数)
(虽然这玩意是平衡树)
(为什么总是要学一些比正解弱的东西啊)
(甚至还用二分树状数组写...)
用权值线段树,节点权值就是对应区间里的数的个数
动态开点避免了离散化(同时也支持在线),不过不太清楚数值区间、数据规模和数组大小的关系...
感觉更重要的还是 动态开点 这件事...
实现的话,就是记录一下左右节点的位置,不再是 x<<1 和 x<<1|1 了。
#include <cstdio>
#include <cstring>
using namespace std;
const int MAXN = 3000005;
const int L = -1e7 - 5;
const int R = 1e7 + 5;
int T;
struct segmentTree {
#define lson ls[x]
#define rson rs[x]
int tot;
int cnt[MAXN], ls[MAXN], rs[MAXN];
segmentTree()
{
tot = 1; // root is 1
memset(cnt, 0, sizeof(cnt));
memset(ls, 0, sizeof(ls));
memset(rs, 0, sizeof(rs));
}
void pushup(int x)
{
cnt[x] = cnt[lson] + cnt[rson];
}
void pushdown(int x)
{
if (!lson) lson = ++tot;
if (!rson) rson = ++tot;
}
// push or pop p
void add(int x, int l, int r, int p, int k)
{
if (l == r) {
cnt[x] += k;
} else {
int mid = (l + r) >> 1;
pushdown(x);
if (p <= mid) add(lson, l, mid, p, k);
else add(rson, mid+1, r, p, k);
pushup(x);
}
}
// counting numbers in [_l, _r]
int query_cnt(int x, int l, int r, int _l, int _r)
{
if (l >=_l && r <=_r) return cnt[x];
int tmp = 0;
int mid = (l + r) >> 1;
pushdown(x);
if (mid >=_l) tmp += query_cnt(lson, l, mid, _l, _r);
if (mid < _r) tmp += query_cnt(rson, mid+1, r, _l, _r);
return tmp;
}
// the number ranking k
int query_number(int x, int l, int r, int k)
{
if (l == r) return l;
int mid = (l + r) >> 1;
if (cnt[lson] >= k) return query_number(lson, l, mid, k);
else return query_number(rson, mid+1, r, k-cnt[lson]);
}
// rank of k
int query_rank(int k)
{
return query_cnt(1, L, R, L, k-1) + 1;
}
// pre of k
int query_prev(int k)
{
int rank = query_cnt(1, L, R, L, k-1);
return query_number(1, L, R, rank);
}
// nxt of k
int query_next(int k)
{
int rank = query_cnt(1, L, R, L, k) + 1;
return query_number(1, L, R, rank);
}
} ST;
int read()
{
register int o = 0, oo = 0;
register char c = getchar();
while (c < '0' || c > '9') oo |= (c == '-'), c = getchar();
while (c >='0' && c <='9') o = (o<<3)+(o<<1)+(c&15), c = getchar();
return oo ? -o : o;
}
int main()
{
T = read();
for (; T; T--) {
int opt = read(), k = read();
if (opt == 1) ST.add(1, L, R, k, 1);
if (opt == 2) ST.add(1, L, R, k, -1);
if (opt == 3) printf("%d\n", ST.query_rank(k));
if (opt == 4) printf("%d\n", ST.query_number(1, L, R, k));
if (opt == 5) printf("%d\n", ST.query_prev(k));
if (opt == 6) printf("%d\n", ST.query_next(k));
}
}
模板