可持久化线段树学习笔记

· · 个人记录

一、可持久化线段树

有时候我们会遇到这样的题目:还是普通的线段树操作,但要求在历史版本上进行操作。你也许会一头雾水,但不用担心,解决它的方法就是今天讲的可持久化线段树。

它的建树和动态开点线段树一样,这里就不再讲了。

Code:

void build (int &o , int l , int r) {
    if (!o) o = ++cnt ;
    if (l == r) {
        t[o].sum = a[l] ;
        return ;
    }
    int mid = l + r >> 1 ;
    build (t[o].lc , l , mid) ;
    build (t[o].rc , mid + 1 , r) ;
    push_up (o) ; 
}

二、修改

首先,有一个很暴力的思想就是每次操作时把整棵树复制一份,建立一个新的版本,但这样的空间复杂度是 \Theta (n^2) 的,和暴力没什么区别。

但我们知道,无论是单点修改还是区间修改,每次最多修改 \Theta(\log n) 个节点。因此我们可以像下面一样建立新版本。

以单点修改为例,步骤如下图:

假如我们要在版本 0 的基础上,修改 4 号节点:

首先,我们将版本 0 的根节点拷贝一份,将它的左右儿子设为 rt_0 的左右儿子,如下图:

紧接着,我们发现要修改它的左儿子,于是我们把原来版本的左儿子拷贝一份,将 rt_1 的左儿子指向新建的节点,如下图:

然后我们发现要修改右儿子,于是像上面一样操作即可:

最后,我们到达了叶子节点,修改完成!

Code:

void upd (int &o , int l , int r , int x , int k) { 
    t[++cnt].lc = t[o].lc ;
    t[cnt].rc = t[o].rc ;
    t[cnt].sum = t[o].sum + k ; 
    o = cnt ;
    if (l == r) return ;
    int mid = l + r >> 1 ;
    if (x <= mid) upd (t[o].lc , l , mid , x , k) ;
    else upd (t[o].rc , mid + 1 , r , x , k) ;
}

对于区间修改,如果我们每次都进行 push_down 操作的话,就会新建许多没用的节点,空间会爆炸。于是,我们保留 lazy 标记不下传,在查询时再访问即可。

Code:

void upd (int &o , int l , int r , int x , int y , ll k) {
    t[++cnt].lc = t[o].lc ; t[cnt].rc = t[o].rc ;
    t[cnt].sum = t[o].sum ; t[cnt].lz = t[o].lz ;
    o = cnt ;
    t[o].sum += k * (min (r , y) - max (l , x) + 1LL) ;
    if (x <= l && r <= y) {
        t[o].lz += k ;
        return ;
    }
    int mid = l + r >> 1 ;
    if (x <= mid) upd (t[o].lc , l , mid , x , y , k) ;
    if (mid < y) upd (t[o].rc , mid + 1 , r , x , y , k) ;
}

三、查询

和普通的线段树一样,但每次要加上已有的 lazy 标记。

Code:

ll query (int o , int l , int r , int x , int y) {
    if (x <= l && r <= y) return t[o].sum ;
    ll ret = (min (r , y) - max (l , x) + 1) * t[o].lz ;
    int mid = l + r >> 1 ;
    if (x <= mid) ret += query (t[o].lc , l , mid , x , y) ;
    if (mid < y) ret += query (t[o].rc , mid + 1 , r , x , y) ;
    return ret ; 
}

四、例题

给定一个序列,每次求区间 [l,r] 内的第 k 小值。

我们只要将序列离散化,依次加入,把第 i 个加入的数看作第 i 个版本,在权值上建立可持久化线段树即可。我们会发现,第 i 个版本对应的是 [1,i] 中加入的数,因此我们利用前缀和思想,就可以得到 [l,r] 中加入的数了。

Code:

#include <cstdio>
#include <algorithm>
using namespace std ;
const int MAXN = 5e5 + 10 ;
int n , m , a[MAXN] , b[MAXN] , rt[MAXN] , cnt ;
struct sgt {
    int lc , rc , sum ;
    sgt () {lc = rc = sum = 0 ;}
} t[MAXN * 25] ;
void push_up (int o) {
    t[o].sum = t[t[o].lc].sum + t[t[o].rc].sum ;
}
void build (int &o , int l , int r) {
    if (!o) o = ++cnt ;
    if (l == r) return ;
    int mid = l + r >> 1 ;
    build (t[o].lc , l , mid) ;
    build (t[o].rc , mid + 1 , r) ;
}
void upd (int &o , int l , int r , int x) {
    t[++cnt].lc = t[o].lc ;
    t[cnt].rc = t[o].rc ;
    t[cnt].sum = t[o].sum + 1 ;
    o = cnt ;
    if (l == r) return ;
    int mid = l + r >> 1 ;
    if (x <= mid) upd (t[o].lc , l , mid , x) ;
    else upd (t[o].rc , mid + 1 , r , x) ;
}
int query (int o1 , int o2 , int l , int r , int k) {
    if (l == r) return a[l] ;
    int mid = l + r >> 1 , tmp = t[t[o2].lc].sum - t[t[o1].lc].sum ;
    if (tmp >= k) return query (t[o1].lc , t[o2].lc , l , mid , k) ;
    else return query (t[o1].rc , t[o2].rc , mid + 1 , r , k - tmp) ;
}
int main () {
    scanf ("%d %d" , &n , &m) ;
    for (int i = 1 ; i <= n ; i++)
        scanf ("%d" , &a[i]) , b[i] = a[i] ;
    sort (a + 1 , a + n + 1) ;
    int k = unique (a + 1 , a + n + 1) - a - 1 ;
    for (int i = 1 ; i <= n ; i++)
        b[i] = lower_bound (a + 1 , a + k + 1 , b[i]) - a ;
    build (rt[0] , 1 , n) ;
    for (int i = 1 ; i <= n ; i++) {
        rt[i] = rt[i - 1] ;
        upd (rt[i] , 1 , n , b[i]) ;
    }
    while (m--) {
        int l , r , k ;
        scanf ("%d %d %d" , &l , &r , &k) ;
        printf ("%d\n" , query (rt[l - 1] , rt[r] , 1 , n , k)) ;
    }
    return 0 ;
}