主席树

· · 算法·理论

概念:

主席树的全称为可持久化权值线段树,用于维护每一个数在插入线段树后的历史版本,从而进行区间 k 大数查询等操作

基本形态:

如果为每一历史版本新开一棵线段树,空间复杂度将达到 \mathcal O(n^2),是不可接受的

而对于一棵线段树,每一次操作仅会对树上的一条链产生影响,所以每一次新建一条链,再将其余不在链上的点用指针指向即可

第一版本:

第二版本:

第三版本:

写法:

首先建立第一个版本的线段树,然后用链的形式向线段树中插入数字,形成一棵完整的主席树

1.build(l, r)

建立空的线段树作为初始版本

主席树每一个节点可以有多余两个节点的情况,所以不能用堆式建树,只能以动态开点的形式进行建树

该步有时不必要,尤其是值域太大且不能离散化时

1.定义结构体:

struct Node{
    int l, r;
    int cnt;
} tr[M];      // 可持久化数据结构有内存占用大的通病,数组一定开足,能开多大开多大!!!

2.建树函数:

int build(int l, int r)
{
    int p = ++ idx;                                       //插入新节点
    if(l == r) return p;

    int mid = l + r >> 1;
    tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r); //建立左右两子树,与传统线段树相似
    return p;
}

2.update(p, l, r, x)

建立每一个历史版本

### $code:
int update(int p, int l, int r, int x)
{
    int q = ++ idx;
    tr[q] = tr[p];                                      // q 是 p 的复制,直接将其信息进行复制

    if(l == r)
    {
        { 维护所需信息 };
        return q;
    }

    int mid = l + r >> 1;
    if(x <= mid) tr[q].l = update(tr[p].l, l, mid, x); // 建立左子树
    else tr[q].r = update(tr[p].r, mid + 1, r, x);     // 建立右子树
        // 这两步操作时,q 与 p 一定为同步的节点,同时向左或向右即可
    pushup(q);                                         // 将左右子树的信息整合到父节点上
    return q;
}

query(q, p, l, r, k)

以区间 k 大数查询为例

代码释义详见[这里](https://www.acwing.com/activity/content/code/content/7585434/) ### $code:
int query(int q, int p, int l, int r, int k)
{
    if(l == r) return r;
    int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt; 
    int mid = l + r >> 1;
    if(k <= cnt) return query(tr[q].l, tr[p].l, l, mid, k);
    else return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt);
}

Attention:

  1. 使用主席树时大部分情况需要离散化使值域变小,离散后值域长度为离散化数组的大小
  2. 注意 updatequery 时传入的是每一个版本的根

例题:

1.P3834 [模板] 可持久化线段树 2

裸的区间 k 大数查找,建树查询即可

题解链接

2.P4587 [FJOI2016] 神秘数

设当前已考虑的数可组成的所有连续自然数的集合 S 的元素个数为 len

根据题目发现每一次加入一个数 x 时,若 x\gt len + 1,则它一定为神秘数,否则区间长度变为 len + x

所以只需要维护一段区间即可

该区间满足:

  1. 左端点为 len + 1
  2. 右端点为 len+x 可达到的最大范围

每一次对这样的权值区间进行区间内权值和 sum 的查询 (与传统线段树的区间查询类似)

将左端点变为右端点+1,右端点累加 sum 即可完成区间的转移

此题的值域较大且无法离散化,但是求区间权值和时建立完全树不必要,动态开点即可,舍去 build

code:

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define gcu getchar_unlocked
using namespace std;

const int N = 100010, M = 18000010, len = 1e9;
typedef long long ll;
namespace fastread {
    void read(int &x)
    {
        x = 0; int f = 1; char c = gcu();
        while(c < '0' || c > '9') {if(c == '-') f = -1; c = gcu();}
        while(c >= '0' && c <= '9') x = (x << 3) + (x << 1) + (c ^ 48), c = gcu();
        x *= f;
    }
} using fastread::read;

int n, m;
int s[N]; 
struct Node{
    int l, r;
    int sum;
} tr[M];
int root[N], idx;

int update(int p, int l, int r, int x)
{
    int q = ++ idx;
    tr[q] = tr[p];
    if(l == r)
    {
        tr[q].sum += x;
        return q;
    }

    ll mid = (ll)(l + r) >> 1ll;
    if(x <= mid) tr[q].l = update(tr[p].l, l, mid, x);
    else tr[q].r = update(tr[p].r, mid + 1, r, x);
    tr[q].sum = tr[tr[q].l].sum + tr[tr[q].r].sum;
    return q;
}

int query(int p, int q, int l, int r, int suml, int sumr)
{
    if(suml <= l && sumr >= r) return tr[q].sum - tr[p].sum;

    ll mid = (ll)(l + r) >> 1ll;
    int sum = 0;
    if(suml <= mid) sum = query(tr[p].l, tr[q].l, l, mid, suml, sumr);
    if(sumr > mid) sum += query(tr[p].r, tr[q].r, mid + 1, r, suml, sumr);
    return sum;
}

int main()
{
    read(n);
    for(int i = 1; i <= n; i ++ ) read(s[i]);

    root[0] = ++ idx;
    for(int i = 1; i <= n; i ++ ) root[i] = update(root[i - 1], 1, len, s[i]);

    read(m);
    int l, r;
    while(m -- )
    {
        read(l), read(r);
        int suml = 0, sumr = 0;
        while(true)
        {
            int sum = query(root[l - 1], root[r], 1, len, suml + 1, sumr + 1);
            if(sum == 0) break;
            suml = sumr + 1, sumr += sum;
        }
        printf("%d\n", sumr + 1);
    }
    return 0;
}

拓展:动态主席树

思想:

如果对主席树进行单点修改,用静态主席树的时间复杂度为 \mathcal O(nlogn),显然不可行

但是,由静态主席树可知,主席树具有前缀和的特性,所以可以用树状数组代替本来的线性存储,即使用树状数组套主席树

此时修改的时间复杂度为 \mathcal O(log^2n)

例题:

1.P2617 Dynamic Rankings

查询时预先将要查找的树记下,节点下传时一起下传即可实现查找

仅仅是将静态主席树中数组的操作全部更换为树状数组的操作而已

可以参考这篇文章

code:

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define gcu getchar_unlocked 
using namespace std;
const int N = 100010, M = 40000100, Log = 20;
typedef long long ll;

namespace fastread {
    void read(int &x)
    {
        x = 0; int f = 1; char c = gcu();
        while(c < '0' || c > '9') {if(c == '-') f = -1; c = gcu();}
        while(c >= '0' && c <= '9') x = (x << 3) + (x << 1) + (c ^ 48), c = gcu();
        x *= f;
    }
} using fastread::read;

int n, m;
int s[N], len;
vector<int> all;
char op[N][2];
struct QUERY{
    int a, b, kth;
} que[N];

namespace president_tree {
    struct Node{
        int l, r;
        int cnt;
    }tr[M];
    int root[N], idx;
    int x[Log], cntx, y[Log], cnty;

    int find(int w) {return lower_bound(all.begin(), all.end(), w) - all.begin();}

    int lowbit(int w) {return w & -w;}

    int build(int l, int r)
    {
        int p = ++ idx;
        if(l == r) return p;

        int mid = l + r >> 1;
        tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
        return p;
    }

    int update(int pre, int l, int r, int w, int v){
        int p = ++ idx;
        tr[p] = tr[pre];
        if(l == r)
        {
            tr[p].cnt += v;
            return p;
        }

        int mid = l + r >> 1;
        if(w <= mid) tr[p].l = update(tr[pre].l, l, mid, w, v);
        else tr[p].r = update(tr[pre].r, mid + 1, r, w, v);
        tr[p].cnt = tr[tr[p].l].cnt + tr[tr[p].r].cnt;
        return p;
    }

    void add(int w, int v)
    {
        int pos = find(s[w]);
        for(; w <= n; w += lowbit(w))
            root[w] = update(root[w], 0, len - 1, pos, v);
    }

    int query(int l, int r, int k)
    {
        if(l == r) return l;
        int mid = l + r >> 1;
        int sum = 0;

        for(int i = 1; i <= cntx; i ++ ) 
            sum -= tr[tr[x[i]].l].cnt;
        for(int i = 1; i <= cnty; i ++ ) 
            sum += tr[tr[y[i]].l].cnt;

        if(k <= sum)
        {
            for(int i = 1; i <= cntx; i ++ ) 
                x[i] = tr[x[i]].l;
            for(int i = 1; i <= cnty; i ++ ) 
                y[i] = tr[y[i]].l;
            return query(l, mid, k);
        }
        else 
        {
            for(int i = 1; i <= cntx; i ++ ) 
                x[i] = tr[x[i]].r;
            for(int i = 1; i <= cnty; i ++ ) 
                y[i] = tr[y[i]].r;
            return query(mid + 1, r, k - sum);
        }
    }
} using namespace president_tree;

int main()
{
    read(n), read(m);
    for(int i = 1; i <= n; i ++ ) 
    {
        read(s[i]);
        all.push_back(s[i]);
    }

    int a, b, c;
    for(int i = 1; i <= m; i ++ ) 
    {
        scanf("%s", op[i]);
        if(*op[i] == 'Q') 
        {
            read(a), read(b), read(c);
            que[i] = {a, b, c};
        }
        else 
        {
            read(a), read(b);
            all.push_back(b);
            que[i].a = a, que[i].b = b;
        }
    }

    sort(all.begin(), all.end());
    all.erase(unique(all.begin(), all.end()), all.end());
    len = all.size();

    root[0] = build(0, len - 1);

    for(int i = 1; i <= n; i ++ ) add(i, 1);

    for(int i = 1; i <= m; i ++ ) 
    {
        if(*op[i] == 'Q') 
        {
            cntx = cnty = 0;
            int l = que[i].a, r = que[i].b, k = que[i].kth;
            for(int j = l - 1; j; j -= lowbit(j))
                x[ ++ cntx] = root[j];
            for(int j = r; j; j -= lowbit(j))
                y[ ++ cnty] = root[j];
            printf("%d\n", all[query(0, len - 1, k)]);
        }
        else 
        {
            int a = que[i].a, val = que[i].b;
            add(a, -1);
            s[a] = val;
            add(a, 1);
        }
    }
    return 0;
}

2.P3380 [模板] 二逼平衡树 (树套树)

只是上一题加入了一些操作

所谓 查询 k 在区间 [l,\ r] 的排名即为将所有权值小于 k 的权值的个数累加即可

而查询前驱只需找到 k 的排名 pos 后输出 pos + 1 的值即可

查询后继是需要找到 k+1 的排名 pos 后输出 pos 的值即可 (可以避免出现 k 不在区间内的情况)

为了判断边界,将两边界值也插入树中即可

code:

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N = 100010, M = 37000010, INF = 2147483647, Log = 25;
typedef long long ll;

namespace getchar_unlocked_fastread {
    const void read(int &x)
    {
        x = 0; int f = 1; char c = getchar_unlocked();
        while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar_unlocked();}
        while(c >= '0' && c <= '9') x = (x << 3) + (x << 1) + (c ^ 48), c = getchar_unlocked();
        x *= f;
    }
}using getchar_unlocked_fastread::read;

int n, m;
int s[N];
vector<int> all;
struct Node{
    int l, r;
    int cnt;
} tr[M];
struct QUERY{
    int op, a, b, c;
} que[N];
int root[N], idx;
int x[Log], cntx, y[Log], cnty, len;

int find(int w) {return lower_bound(all.begin(), all.end(), w) - all.begin();}

int lowbit(int w) {return w & -w;}

int build(int l, int r)
{
    int p = ++ idx;
    if(l == r) return p;

    int mid = l + r >> 1;
    tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
    return p;
}

int update(int p, int l, int r, int ver, int v)
{
    int q = ++ idx;
    tr[q] = tr[p];
    if(l == r)
    {
        tr[q].cnt += v;
        return q;
    }

    int mid = l + r >> 1;
    if(ver <= mid) tr[q].l = update(tr[p].l, l, mid, ver, v);
    else tr[q].r = update(tr[q].r, mid + 1, r, ver, v);
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
    return q;
}

void add(int ver, int v)
{
    int p = find(s[ver]);
    for(; ver <= n; ver += lowbit(ver))
        root[ver] = update(root[ver], 0, len - 1, p, v);
}

void move(int l, int r)
{
    cntx = cnty = 0;
    for(int i = l - 1; i; i -= lowbit(i))
        x[ ++ cntx] = root[i];
    for(int i = r; i; i -= lowbit(i))
        y[ ++ cnty] = root[i];
}

int query_rank(int l, int r, int k)
{
    if(l == k && r == k) return 1;

    int cnt = 0;
    for(int i = 1; i <= cntx; i ++ ) 
        cnt -= tr[tr[x[i]].l].cnt;
    for(int i = 1; i <= cnty; i ++ ) 
        cnt += tr[tr[y[i]].l].cnt;

    int mid = l + r >> 1;
    if(k <= mid)
    {
        for(int i = 1; i <= cntx; i ++ ) 
            x[i] = tr[x[i]].l;
        for(int i = 1; i <= cnty; i ++ ) 
            y[i] = tr[y[i]].l;

        return query_rank(l, mid, k);
    } 
    else
    {
        for(int i = 1; i <= cntx; i ++ ) 
            x[i] = tr[x[i]].r;
        for(int i = 1; i <= cnty; i ++ ) 
            y[i] = tr[y[i]].r;

        return cnt + query_rank(mid + 1, r, k);
    }
}

int query_val(int l, int r, int k)
{
    if(l == r) return l;

    int cnt = 0;
    for(int i = 1; i <= cntx; i ++ ) 
        cnt -= tr[tr[x[i]].l].cnt;
    for(int i = 1; i <= cnty; i ++ ) 
        cnt += tr[tr[y[i]].l].cnt;

    int mid = l + r >> 1;
    if(k <= cnt)
    {
        for(int i = 1; i <= cntx; i ++ ) 
            x[i] = tr[x[i]].l;
        for(int i = 1; i <= cnty; i ++ ) 
            y[i] = tr[y[i]].l;

        return query_val(l, mid, k);
    }
    else 
    {
        for(int i = 1; i <= cntx; i ++ ) 
            x[i] = tr[x[i]].r;
        for(int i = 1; i <= cnty; i ++ ) 
            y[i] = tr[y[i]].r;

        return query_val(mid + 1, r, k - cnt);
    }
}

int query_pre(int a, int b, int l, int r, int k)
{
    int rank = query_rank(0, len - 1, k);
    move(a, b);
    return all[query_val(0, len - 1, rank - 1)];
}

int query_suc(int a, int b, int l, int r, int k)
{
    int rank = query_rank(0, len - 1, k + 1); // 直接查找 k + 1 的排名可以避免判断值是否在当前序列中
    move(a, b);
    return all[query_val(0, len - 1, rank)];
}

int main()
{
    read(n), read(m);
    for(int i = 1; i <= n; i ++ )
    {
        read(s[i]);
        all.push_back(s[i]);
    }

    int op, a, b, c;
    for(int i = 1; i <= m; i ++ ) 
    {
        read(op);
        if(op == 1)
        {
            read(a), read(b), read(c);
            que[i] = {op, a, b, c};
            all.push_back(que[i].c);
        }
        else if(op == 2)
        {
            read(a), read(b), read(c);
            que[i] = {op, a, b, c};
        }
        else if(op == 3)
        {
            read(a), read(b);
            que[i].op = op, que[i].a = a, que[i].b = b;
            all.push_back(b);
        }
        else if(op == 4)
        {
            read(a), read(b), read(c);
            que[i] = {op, a, b, c};
            all.push_back(c);
        }
        else 
        {
            read(a), read(b), read(c);
            que[i] = {op, a, b, c};
            all.push_back(c);
        }
    } 
    all.push_back(-INF), all.push_back(INF);

    sort(all.begin(), all.end());
    all.erase(unique(all.begin(), all.end()), all.end());
    len = all.size();

    root[0] = build(0, len - 1);

    for(int i = 1; i <= n; i ++ ) add(i, 1);

    int l, r, k;
    for(int i = 1; i <= m; i ++ ) 
    {
        l = que[i].a, r = que[i].b, k = que[i].c;

        if(que[i].op != 3) move(l, r);

        if(que[i].op == 1) printf("%d\n", query_rank(0, len -1, find(k)));
        else if(que[i].op == 2) printf("%d\n", all[query_val(0, len - 1, k)]);
        else if(que[i].op == 3)
        {
            add(l, -1);
            s[l] = r;
            add(l, 1);
        }
        else if(que[i].op == 4) printf("%d\n", query_pre(l, r, 0, len - 1, find(k)));
        else if(que[i].op == 5) printf("%d\n", query_suc(l, r, 0, len - 1, find(k)));
    }

    return 0;
}