可持久化线段树学习笔记
一、可持久化线段树
有时候我们会遇到这样的题目:还是普通的线段树操作,但要求在历史版本上进行操作。你也许会一头雾水,但不用担心,解决它的方法就是今天讲的可持久化线段树。
它的建树和动态开点线段树一样,这里就不再讲了。
Code:
void build (int &o , int l , int r) {
if (!o) o = ++cnt ;
if (l == r) {
t[o].sum = a[l] ;
return ;
}
int mid = l + r >> 1 ;
build (t[o].lc , l , mid) ;
build (t[o].rc , mid + 1 , r) ;
push_up (o) ;
}
二、修改
首先,有一个很暴力的思想就是每次操作时把整棵树复制一份,建立一个新的版本,但这样的空间复杂度是
但我们知道,无论是单点修改还是区间修改,每次最多修改
以单点修改为例,步骤如下图:
假如我们要在版本
首先,我们将版本
紧接着,我们发现要修改它的左儿子,于是我们把原来版本的左儿子拷贝一份,将
然后我们发现要修改右儿子,于是像上面一样操作即可:
最后,我们到达了叶子节点,修改完成!
Code:
void upd (int &o , int l , int r , int x , int k) {
t[++cnt].lc = t[o].lc ;
t[cnt].rc = t[o].rc ;
t[cnt].sum = t[o].sum + k ;
o = cnt ;
if (l == r) return ;
int mid = l + r >> 1 ;
if (x <= mid) upd (t[o].lc , l , mid , x , k) ;
else upd (t[o].rc , mid + 1 , r , x , k) ;
}
对于区间修改,如果我们每次都进行 push_down 操作的话,就会新建许多没用的节点,空间会爆炸。于是,我们保留
Code:
void upd (int &o , int l , int r , int x , int y , ll k) {
t[++cnt].lc = t[o].lc ; t[cnt].rc = t[o].rc ;
t[cnt].sum = t[o].sum ; t[cnt].lz = t[o].lz ;
o = cnt ;
t[o].sum += k * (min (r , y) - max (l , x) + 1LL) ;
if (x <= l && r <= y) {
t[o].lz += k ;
return ;
}
int mid = l + r >> 1 ;
if (x <= mid) upd (t[o].lc , l , mid , x , y , k) ;
if (mid < y) upd (t[o].rc , mid + 1 , r , x , y , k) ;
}
三、查询
和普通的线段树一样,但每次要加上已有的
Code:
ll query (int o , int l , int r , int x , int y) {
if (x <= l && r <= y) return t[o].sum ;
ll ret = (min (r , y) - max (l , x) + 1) * t[o].lz ;
int mid = l + r >> 1 ;
if (x <= mid) ret += query (t[o].lc , l , mid , x , y) ;
if (mid < y) ret += query (t[o].rc , mid + 1 , r , x , y) ;
return ret ;
}
四、例题
给定一个序列,每次求区间
我们只要将序列离散化,依次加入,把第
Code:
#include <cstdio>
#include <algorithm>
using namespace std ;
const int MAXN = 5e5 + 10 ;
int n , m , a[MAXN] , b[MAXN] , rt[MAXN] , cnt ;
struct sgt {
int lc , rc , sum ;
sgt () {lc = rc = sum = 0 ;}
} t[MAXN * 25] ;
void push_up (int o) {
t[o].sum = t[t[o].lc].sum + t[t[o].rc].sum ;
}
void build (int &o , int l , int r) {
if (!o) o = ++cnt ;
if (l == r) return ;
int mid = l + r >> 1 ;
build (t[o].lc , l , mid) ;
build (t[o].rc , mid + 1 , r) ;
}
void upd (int &o , int l , int r , int x) {
t[++cnt].lc = t[o].lc ;
t[cnt].rc = t[o].rc ;
t[cnt].sum = t[o].sum + 1 ;
o = cnt ;
if (l == r) return ;
int mid = l + r >> 1 ;
if (x <= mid) upd (t[o].lc , l , mid , x) ;
else upd (t[o].rc , mid + 1 , r , x) ;
}
int query (int o1 , int o2 , int l , int r , int k) {
if (l == r) return a[l] ;
int mid = l + r >> 1 , tmp = t[t[o2].lc].sum - t[t[o1].lc].sum ;
if (tmp >= k) return query (t[o1].lc , t[o2].lc , l , mid , k) ;
else return query (t[o1].rc , t[o2].rc , mid + 1 , r , k - tmp) ;
}
int main () {
scanf ("%d %d" , &n , &m) ;
for (int i = 1 ; i <= n ; i++)
scanf ("%d" , &a[i]) , b[i] = a[i] ;
sort (a + 1 , a + n + 1) ;
int k = unique (a + 1 , a + n + 1) - a - 1 ;
for (int i = 1 ; i <= n ; i++)
b[i] = lower_bound (a + 1 , a + k + 1 , b[i]) - a ;
build (rt[0] , 1 , n) ;
for (int i = 1 ; i <= n ; i++) {
rt[i] = rt[i - 1] ;
upd (rt[i] , 1 , n , b[i]) ;
}
while (m--) {
int l , r , k ;
scanf ("%d %d %d" , &l , &r , &k) ;
printf ("%d\n" , query (rt[l - 1] , rt[r] , 1 , n , k)) ;
}
return 0 ;
}