主席树/可持久化线段树 学习笔记

· · 算法·理论

可能另一种阅读体验

update:

2026.4.22:被第一次打回,修改错误。\ 2026.4.24:被第二次打回,修改错误,重新提交审核,管理员辛苦了。

定义

我们引用一下 OI Wiki 的定义:

可持久化数据结构(Persistent data structure)总是可以保留每一个历史版本,并且支持操作的不可变特性(immutable)。 ——OI Wiki

故可持久化线段树就是可以存储多个历史版本,并且进行一些操作的线段树。

用途

可持久化线段树是很多可持久化数据结构的起点,它允许查询,修改某一版本的某一值,或者查询静态区间第 k 小。

值得一提的是,允许访问或修改所有版本的可持久化叫做完全可持久化。\ 详见:可持久化数据结构简介 - OI Wiki。

基础实现

对于多个版本,我们很明显可以对于每一个版本建一个线段树,但是这样,我们的空间会爆炸!

故我们考虑让前面的版本对当前版本有一定贡献,考虑对于每一个版本,有以下两种情况:

  1. 如果不需要修改节点,我们就让指针指向上一个版本的这个未修改节点。
  2. 如果是需要修改节点,我们就对其新建一个节点。\ 如图:\

其实现依赖于用结构体维护 ls, rs,并且对于一个新节点,我们先让其等于原节点,并且给其一个新的编号,再依次更新下面的节点。

代码如下:

struct node{
    int val, ls, rs;
}nd[N << 7];
#define ls(k) nd[k].ls
#define rs(k) nd[k].rs
int newnode(int p){nd[++ tot] = nd[p];return tot;}

修改操作

可见,因为每一个要修改的节点需要有一个新的编号,所以我们需要让这个应修改节点变为这个编号,其实现方法即为在 update 操作中将引入 p 编号的节点加上址传递,即加上取地址符 &。其他和普通线段树没有什么很大的差异。

代码如下:

void update(int &p, int l, int r, int x, int val)
{
    p = newnode(p);//变为新编号
    if(l == r){nd[p].val = val;return ;}//更新值
    int mid = l + r >> 1;
    if(x <= mid) update(ls(p), l, mid, x, val);
    else update(rs(p), mid + 1, r, x, val);
    //pushup(p);
     //pushup 操作看题,可有可无
}

查询操作

因为所有版本我们已经处理好了,直接查询即可。当然,你也可以使用可持久化数组。

代码如下:

ll query(int p, int l, int r, int x)
{
    if(l == r) return nd[p].val;
    int mid = l + r >> 1;
    if(x <= mid) return query(ls(p), l, mid, x);
    else return query(rs(p), mid + 1, r, x);
}

基础实现例题

1. P3919 【模板】可持久化线段树 1(可持久化数组)

简要题意:

维护一个长度为 n 的数组 a,进行 m 次以下操作:

  1. 形如 v 1 p c 表示把版本 va_p 改为 c
  2. 形如 v 2 p 输出版本 va_p

对于每个操作,都会生成一个对应的版本。

其中 1 \le n,m \le 10 ^ 61 \le p \le n-10^9 \le a_i,c \le 10^9。\ 如果当前是第 x 次操作,则 0 \le v < x。\ 版本 0 为原始数组。

思路

按照上面代码和思想,模拟即可。

:::success[代码]

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int N = 1e6 + 66;
int root[N], tot;
struct Seg_Ment_Tree{
    struct node{
        int val, ls, rs;
    }nd[N << 7];
    #define ls(k) nd[k].ls
    #define rs(k) nd[k].rs
    int newnode(int p){nd[++ tot] = nd[p];return tot;}
    void update(int &p, int l, int r, int x, int val)
    {
        p = newnode(p);
        if(l == r){nd[p].val = val;return ;}
        int mid = l + r >> 1;
        if(x <= mid) update(ls(p), l, mid, x, val);
        else update(rs(p), mid + 1, r, x, val);
    }
    ll query(int p, int l, int r, int x)
    {
        if(l == r) return nd[p].val;
        int mid = l + r >> 1;
        if(x <= mid) return query(ls(p), l, mid, x);
        else return query(rs(p), mid + 1, r, x);
    }
}seg;
int n, m, a;
int main()
{
    ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for(int i = 1;i <= n;i ++) cin >> a, seg.update(root[0], 1, n, i, a);//初始化初始版本(即原数组)
    for(int i = 1;i <= m;i ++)
    {
        int od, v, p, c;
        cin >> v >> od >> p;
        root[i] = root[v];//当前版本为给定版本
        if(od == 1)
        {
            cin >> c;
            seg.update(root[i], 1, n, p, c);//修改
        }
        else cout << seg.query(root[i], 1, n, p) << '\n';
    }
    return 0;
}

:::

2. P1383 高级打字机

简要题意:

开始有一空串,要进行 n 次以下操作:

  1. 形如 T x,表示在后面加入一字符 x
  2. 形如 U x,表示撤销 x 次除操作三的操作。
  3. 形如 Q x 表示查询当前字符串的第 x 个字符。
#### 思路 板子,按照模板 1 查询即可,注意输入的字符还是数字,不然会 WA,还有注意分辨一下操作二和操作一都属于修改操作,所以需要新建一个版本,具体看代码。 :::success[代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 1e5 + 66; int root[N], tot; struct Seg_Ment_Tree{ struct node{ int val, ls, rs, su; }nd[N << 7]; #define ls(k) nd[k].ls #define rs(k) nd[k].rs int newnode(int p){nd[++ tot] = nd[p];return tot;} void pushup(int p){nd[p].su = nd[ls(p)].su + nd[rs(p)].su;}//维护字符个数 void update(int &p, int l, int r, int x, int val) { p = newnode(p); if(l == r){nd[p].val = val;nd[p].su = 1;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); pushup(p); } ll query(int p, int l, int r, int x) { if(l == r) return nd[p].val; int mid = l + r >> 1; if(x <= mid) return query(ls(p), l, mid, x); else return query(rs(p), mid + 1, r, x); } }seg; int n, m, a, cnt; int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); cin >> n; for(int i = 1;i <= n;i ++) { char c;cin >> c; if(c == 'T') { char s;cin >> s; cnt ++;//新建版本 root[cnt] = root[cnt - 1];//等于上一个版本 seg.update(root[cnt], 1, 1e5, seg.nd[root[cnt]].su + 1, s);//在上一个版本的后面加入一个字符 } else if(c == 'U') { int s;cin >> s; cnt ++;//新建版本 root[cnt] = root[cnt - 1 - s];//等于 cnt - s 的,但是我们 ++ 了,所以让 cnt - 1 } else { int s;cin >> s; cout << char(seg.query(root[cnt], 1, 1e5, s)) << '\n';//直接查询即可 } } return 0; } ``` ::: ## 进阶一点的实现 现在我们讨论静态区间第 $k$ 小。 先引入一下例题吧: ### [P3834 【模板】可持久化线段树 2](https://www.luogu.com.cn/problem/P3834) 简要题意: 给定一个长度为 $n$ 的序列 $a$,询问 $m$ 次 $a$ 中 $[l,r]$ 区间第 $k$ 小的数的值。$1 \le n,m \le 2 \times 10 ^ 5$,$0 \le a_i \le 10^9$,$1 \le l, r \le n$,$1 \le k \le r - l + 1$。 可见,我们很难在原序列中,用原本思想直接在下标 $i$ 修改,因为我们第 $k$ 小是关系大小,而在下标 $i$ 插入却无法比较大小的。 :::info[那大小我们怎么处理呢?] 考虑使用权值线段树来处理,这样我们就可以处理大小关系。 ::: :::info[那 $a_i$ 的范围太大了怎么办?] 可持久化线段树是自带动态开点性质的,我们只需关注题目的空间大小是不是过小了(比如 $125MB$),如果太小了,就离散化一下即可。 ::: ### 修改 对于每一个位置,我们考虑建立一个新版本,并在新版本所维护的权值线段树上插入它的值。 所以修改代码其实是没什么区别的,但是 $a_i$ 可能重复,所以不能直接赋值。 考虑我们还需要维护什么,以及我们需不需要 `pushup`。显然的是我们需要查询前面区间的数的个数,所以我们需要 `pushup` 来让父节点维护区间。 代码如下: ```cpp void pushup(int p){nd[p].val = nd[ls(p)].val + nd[rs(p)].val;} void update(int &p, int l, int r, int x, int val) { p = newnode(p); if(l == r){nd[p].val += val;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); pushup(p); } ``` ### 查询 但是我们怎么查询的?我们考虑,对于 $[l,r]$ 我们现在已经有了对应的版本,那一个 $[1,l - 1]$,$[1,r]$ 内有的数的个数的差不就是 $[l,r]$ 中数的个数吗?考虑到这个情况也可以推广到其他区间,所以我们在查询时传入两个节点,每一次判断前面的区间的数是否够 $k$,如果不够就在后面的区间里,但是此时前面是有贡献的,所以注意此时查询的 $k$ 就要减去前面的数的个数了。 注意传入的节点是 $l-1$ 和 $r$ 所对应版本即可。 对了,此时两个节点要同时跳左右儿子哦。 代码如下: ```cpp ll query(int u, int v, int l, int r, int x) { if(l == r) return l;//因为我们建的是权值线段树,此时 l 就是我们对应的值 int k = nd[ls(v)].val - nd[ls(u)].val; int mid = l + r >> 1; if(k >= x) return query(ls(u), ls(v), l, mid, x); else return query(rs(u), rs(v), mid + 1, r, x - k); } ``` :::info[全部代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 1e6 + 66; int root[N], tot; struct Seg_Ment_Tree{ struct node{ int val, ls, rs; }nd[N << 7]; #define ls(k) nd[k].ls #define rs(k) nd[k].rs int newnode(int p){nd[++ tot] = nd[p];return tot;} void pushup(int p){nd[p].val = nd[ls(p)].val + nd[rs(p)].val;} void update(int &p, int l, int r, int x, int val) { p = newnode(p); if(l == r){nd[p].val += val;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); pushup(p); } ll query(int u, int v, int l, int r, int x) { if(l == r) return l; int k = nd[ls(v)].val - nd[ls(u)].val; int mid = l + r >> 1; if(k >= x) return query(ls(u), ls(v), l, mid, x); else return query(rs(u), rs(v), mid + 1, r, x - k); } }seg; int n, m, a; int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); cin >> n >> m; for(int i = 1;i <= n;i ++) { cin >> a; root[i] = root[i - 1]; seg.update(root[i], 0, 1e9, a, 1); } for(int i = 1;i <= m;i ++) { int l, r, k;cin >> l >> r >> k; cout << seg.query(root[l - 1], root[r], 0, 1e9, k) << '\n'; } return 0; } ``` ::: ## 例题 ### 1. [P1533 可怜的狗狗](https://www.luogu.com.cn/problem/P1533) 简要题意是和上面一样的,不重新打一遍了。\ 数据范围:$1\le n \le 3\times 10^5$,$1\le m \le5\times10^4$,$0\le a_i<2^{31}$,且 $a_i$ 互不相同。 其实这个就是模板二,注意空间限制比较小,离散化一下就可以了。 :::success[代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 3e5 + 66; int root[N], tot, inf = INT_MAX; struct Seg_Ment_Tree{ struct node{ int val, ls, rs; }nd[N << 5]; #define ls(k) nd[k].ls #define rs(k) nd[k].rs int newnode(int p){nd[++ tot] = nd[p];return tot;} void pushup(int p){nd[p].val = nd[ls(p)].val + nd[rs(p)].val;} void update(int &p, int l, int r, int x, int val) { p = newnode(p); if(l == r){nd[p].val += val;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); pushup(p); } int query(int u, int v, int l, int r, int x) { if(l == r) return l; int k = nd[ls(v)].val - nd[ls(u)].val; int mid = l + r >> 1; if(k >= x) return query(ls(u), ls(v), l, mid, x); else return query(rs(u), rs(v), mid + 1, r, x - k); } }seg; int n, m, a[N], b[N]; int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); cin >> n >> m; for(int i = 1;i <= n;i ++) { cin >> a[i];b[i] = a[i]; } sort(b + 1, b + 1 + n); for(int i = 1;i <= n;i ++) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b; for(int i = 1;i <= n;i ++) { root[i] = root[i - 1]; seg.update(root[i], 1, n, a[i], 1); } for(int i = 1;i <= m;i ++) { int l, r, k;cin >> l >> r >> k; cout << b[seg.query(root[l - 1], root[r], 1, n, k)] << '\n'; } return 0; } ``` ::: ### 2. [P3168 \[CQOI2015\] 任务查询系统](https://www.luogu.com.cn/problem/P3168) 简要题意: 有 $n$ 个任务,其统治区间为 $[l_i,r_i]$,权值为 $p_i$。 询问 $m$ 次,强制在线,询问对于 $x_i$,权值最小的 $k_i$ 的任务的权值和。 $1\le m,n \le 10 ^ 5$,$1\leq l_i\leq r_i\leq n$,$1\le p_i \le 10^7$,$x_i$ 为 $1$ 到 $n$ 的一个排列。 这不是逗我呢吗?《\ 主席树不是只能单点插入吗,怎么还要区间插入??? 等等,你别急,主席树虽然只能单点插入,但是它的版本传递有类似于**前缀和**的作用,所以我们对于 $l_i$ 和 $r_i + 1$ 进行权值线段树,但是差分就可以了。注意一个位置的数可以有许多,所以注意使用 vector 记录每个数。 查询时注意是权值和,所以我们要把左区间的贡献也算上。 :::success[代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 1e5 + 66; struct Persistent_Segment_Tree{ struct node{ ll val, su; int rs, ls; }nd[N << 7]; int tot; #define ls(k) nd[k].ls #define rs(k) nd[k].rs #define val(k) nd[k].val #define su(k) nd[k].su void pushup(int p){val(p) = val(ls(p)) + val(rs(p));su(p) = su(ls(p)) + su(rs(p));} int newnode(int p){nd[++ tot] = nd[p];return tot;} void update(int &p, int l, int r, ll x, ll val) { p = newnode(p); if(l == r){val(p) += val;su(p) += x * val;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); pushup(p); } ll query(int p, int l, int r, ll x) { if(l == r) return min(su(p), x * l);//注意可能有多个 l,比较一下取最小即可 int mid = l + r >> 1; if(x <= val(ls(p))) return query(ls(p), l, mid, x); else return query(rs(p), mid + 1, r, x - val(ls(p))) + su(ls(p)); } }seg; ll Abs(ll x){return x > 0 ? x : -x;} ll n, m, mx; int root[N]; vector<ll> v[N]; int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); cin >> m >> n; for(int i = 1;i <= m;i ++) { ll x, y, z;cin >> x >> y >> z; v[x].push_back(z), v[y + 1].push_back(-z); mx = max(mx, z); } for(int i = 1;i <= n;i ++) { root[i] = root[i - 1]; for(auto p : v[i]) seg.update(root[i], 1, mx, Abs(p), (p > 0 ? 1ll : -1ll)); } ll lst = 1; for(int i = 1;i <= n;i ++) { ll x, a, b, c;cin >> x >> a >> b >> c; a = (lst * a + b) % c + 1; lst = seg.query(root[x], 1, mx, a); cout << lst << '\n'; } return 0; } ``` ::: ## 其他可持久化结构 ### 1. 可持久化数组(非可持久化线段树实现) 一种可以访问历史版本的数组,其实现依赖于每个版本对于其之前版本有一个变量指向,对于每次修改,我们都找到其对应位置的上一次修改的位置,并且在其基础上进行修改。这样我们就不用每一次复制之前的版本,以保证了空间复杂度。 **优点**: - 空间复杂度相对可持久化线段树较小 - 好写 **缺点**: - 时间复杂度依赖于修改次数,可能会被卡 #### 定义 写法其实就是定义一个结构体,维护前驱,当前的值,当前的位置,用当前版本查询。\ 我们用可持久化数组,即可持久化线段树 1 来讲。\ 代码如下: ```cpp struct Persistent{ int val, i, pre; }; vector<Persistent> vl(N); ``` #### 更改 如果是加减乘除法,那还需要找前驱更改,但这里是覆盖,直接新建节点覆盖就可以了。\ 代码如下: ```cpp void change(int v, int p, int c) { now ++; /* int s = v; while(vl[s].i != p && s) s = vl[s].pre; if(s == 0) 如果是加减乘除等的话,用得到这里 */ vl[now].val = c;vl[now].i = p;vl[now].pre = v; } ``` #### 查询 按照找前驱的思想就可以了,生成一个完全相同的就是不更改原节点的值,然后指向 $v$ 就可以了。注意如果没找到就是原数组里的值。\ 代码如下: ```cpp int query(int v, int p) { now ++; int s = v; while(vl[s].i != p && s) s = vl[s].pre; vl[now].i = p;vl[now].pre = v; //题目中要求生成一个完全一样的版本 if(s == 0) return vl[now].val = a[p]; return vl[now].val = vl[s].val; } ``` :::success[完整代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 1e6 + 66; struct Persistent{ int val, i, pre; }; vector<Persistent> vl(N); int n, m, a[N], now; template<typename T> void read(T&x) { char c;int sign = 1;x = 0; // 判断是否是负数,以及去掉前面多余的字符 do { c = getchar(); if(c == '-') sign = -1; }while(!isdigit(c)); do { x = x * 10 + c - '0'; c = getchar(); }while(isdigit(c)); x *= sign; } //a 是原数组,因为我的写法没有方法存 a //now 是当前版本 void change(int v, int p, int c) { now ++; /* int s = v; while(vl[s].i != p && s) s = vl[s].pre; if(s == 0) 如果是加的话,用得到这里 */ vl[now].val = c;vl[now].i = p;vl[now].pre = v; } int query(int v, int p) { now ++; int s = v; while(vl[s].i != p && s) s = vl[s].pre; vl[now].i = p;vl[now].pre = v; //题目中要求生成一个完全一样的版本 if(s == 0) return vl[now].val = a[p]; return vl[now].val = vl[s].val; } int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); read(n), read(m); for(int i = 1;i <= n;i ++) read(a[i]); for(int i = 1;i <= m;i ++) { int v, od, p, c; read(v), read(od), read(p); if(od == 1) { read(c); change(v, p, c); } else cout << query(v, p) << '\n'; } return 0; } ``` ::: ### 2. 可持久化并查集 它其实就是基于可持久化线段树维护的并查集每个节点的 $fa$,这样就可以保证其可持久化。 而其唯一一个区别是我们不可以用路径压缩,因为其的时间复杂度是均摊单次 $O(\alpha(n))$ 的。这就导致可能出题人出个卡路径压缩的数据给我们单次操作卡到 $O(n)$ 那不炸了吗。 故我们需要一个严格 $O(\log n)$ 的优化方式,即按秩合并。 按秩合并有多种,比如按深度,按子树大小,~~还有随机化(会被卡)~~。注意此时深度表示的是整棵树的最深节点的深度。\ 其实现就是把深度或者子树大小小的并查集的父亲设为大的并查集。\ 前面两种的时间复杂度为严格 $O(\log n)$ 但是随机化是可能退化到 $O(n)$ 的,具体证明可以参考[可持久化并查集的题解](https://www.luogu.com.cn/article/dc4docp3),这里不过多赘述了。 那我们在每个版本不仅要维护一个 $fa$ 还要维护一个 $sz$(子树大小) 或者 $dep$(深度),但对于每个版本我们都可以同时更新这几个,建两棵可持久化线段树即可。 于是我们的 `find` 函数就变为了: ```cpp int find(int x) { while(fa.query(fa.root[now], 1, n, x) != x) x = fa.query(fa.root[now], 1, n, x); //在当前的版本查找到一个并查集的根,暴力跳即可 return x; } ``` `merge` 就变为了: ```cpp void merge(int x, int y) { x = find(x), y = find(y);//找根 if(x == y) return ; int X = siz.query(siz.root[now], 1, n, x), Y = siz.query(siz.root[now], 1, n, y); //我写的是子树大小的,查询当前版本以 x 和 y 为根的子树大小 if(X <= Y)//小的合并到大的 { fa.update(fa.root[now], 1, n, x, y); siz.update(siz.root[now], 1, n, y, X + Y); } else { fa.update(fa.root[now], 1, n, y, x); siz.update(siz.root[now], 1, n, x, X + Y); } } ``` 对于每一个新版本我们注意继承上一个版本,如果有回溯版本就再另外赋值成那个版本就可以了。 :::success[代码] ```cpp #include<bits/stdc++.h> using namespace std; using ll = long long; constexpr int N = 2e5 + 66; struct segmenttree{ int root[N]; struct node{ int val, ls, rs; }nd[N << 7]; #define val(k) nd[k].val #define ls(k) nd[k].ls #define rs(k) nd[k].rs int tot = 0; int newnode(int p){nd[++ tot] = nd[p];return tot;} void build(int &p, int l, int r, int a[]) { p = ++ tot; if(l == r) {val(p) = a[l];return ;} int mid = l + r >> 1; build(ls(p), l, mid, a); build(rs(p), mid + 1, r, a); } void update(int &p, int l, int r, int x, int val) { p = newnode(p); if(l == r) {val(p) = val;return ;} int mid = l + r >> 1; if(x <= mid) update(ls(p), l, mid, x, val); else update(rs(p), mid + 1, r, x, val); } int query(int p, int l, int r, int x) { if(l == r) return val(p); int mid = l + r >> 1; if(x <= mid) return query(ls(p), l, mid, x); else return query(rs(p), mid + 1, r, x); } }fa, siz; int now; int n, m; int find(int x) { while(fa.query(fa.root[now], 1, n, x) != x) x = fa.query(fa.root[now], 1, n, x); return x; } void merge(int x, int y) { x = find(x), y = find(y); if(x == y) return ; int X = siz.query(siz.root[now], 1, n, x), Y = siz.query(siz.root[now], 1, n, y); if(X <= Y) { fa.update(fa.root[now], 1, n, x, y); siz.update(siz.root[now], 1, n, y, X + Y); } else { fa.update(fa.root[now], 1, n, y, x); siz.update(siz.root[now], 1, n, x, X + Y); } } int main() { ios::sync_with_stdio(false);cin.tie(0), cout.tie(0); cin >> n >> m; for(int i = 1;i <= n;i ++) fa.update(fa.root[0], 1, n, i, i), siz.update(siz.root[0], 1, n, i, 1);//初始化 while(m --) { now ++; fa.root[now] = fa.root[now - 1], siz.root[now] = siz.root[now - 1];//继承上一个版本 int op, x, y;cin >> op; if(op == 1) cin >> x >> y, merge(x, y); else if(op == 3) cin >> x >> y, cout << (find(x) == find(y)) << '\n'; else cin >> x, fa.root[now] = fa.root[x], siz.root[now] = siz.root[x]; } return 0; } ``` :::