题解:AT_abc324_g [ABC324G] Generate Arrays

· · 题解

前言

人生洛谷第一篇题解也写上紫题了!!!!

正文

用此题来对主席树的知识进行一个总结:

我们来回忆一下题目:对于给定一个从 1n 的原始序列(我们再次将其称之为祖先序列),现有如下两种操作:

题目回忆录

s_i 号序列中编号大于 x_i 的元素提出并将其组成一个新的 i 号序列,同时,将其在原序列的位置删除;

s_i 号序列中数值大于 x_i 的元素提出并将其组成一个新的 i 号序列,同时,将其在原序列的位置删除;

在这道题中,两种修改方式反映的是两种不同的区间,一是元素 i 在原始数组中的物理区间,二是元素 i 在原始数组中的值域区间;对于同时维护两种区间的数据结构,我们考虑使用可持久化线段树(即主席树)。

对于每一颗线段树,用来维护当前区间,在某一个值域范围内,数值的出现次数,我们考虑用数组 \min n[i]\max x[i] 分别来维护当前序列的最大值和最小值分别是多少,当触发类型 2 的修改时,我们只需要将原序列的最大值和最小值进行修改为 \min n[i]x[i],同时将新序列最大最小值进行维护,类型 2 就做完了。

minn[i] = max(minn[i], xi + 1);
maxx[si] = min(maxx[si], xi);

难点

我们考虑类型一的修改,由于主席树维护的是值域线段树,所以我们无法直接获取区间 [l[s_i], \text{pos}] 中值在 [\min n[s_i], \max x[s_i]] 范围内的元素个数是否等于 x_i,因此,我们考虑维护一个二分答案,通过主席树 \text{ask}(),暴力的去判断当前这个区间下标是否 \leq x_i,从而通过 \log^2 n 的时间复杂度找到逼近 x_i 的区间,为此我们需要维护 l[i], r[i] 作为当前区间的左右端点下标,在主席树中维护一个 \text{cnt} 表示当前区间内值的数量。

// 我们要找最大的位置 pos,使得区间内有效元素个数 <= xi
if (ask(rt[max(l[si] - 1, 0)], rt[midd], 0, n, minn[si], maxx[si]) <= xi) 
    ll = midd;  // 还可以往右扩展
else 
    rr = midd;  // 太多了,需要往左收缩

最后已知当前序列的值域范围和物理下标范围,套主席树模板就行了(别说这个你也不会敲)。

至此,此题结束。

AC Code

#include <bits/stdc++.h>
#define pir pair<int, int>
using namespace std;
const int N = 2e5 + 5, M = (N << 5) + 5;
int lc[M], rc[M], sum[M], l[N], r[N], maxx[N], minn[N], rt[N], a[N];
int n, tot, q, ti, si, xi, ll, rr, midd;

void add(int x, int &y, int l, int r, int d) {
    y = ++tot;
    lc[y] = lc[x];
    rc[y] = rc[x];
    sum[y] = sum[x] + 1;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (d <= mid) add(lc[x], lc[y], l, mid, d);
    else add(rc[x], rc[y], mid + 1, r, d);
}

int ask(int x, int y, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr) return sum[y] - sum[x];
    int mid = (l + r) >> 1;
    int val = 0;
    if (ql <= mid) val += ask(lc[x], lc[y], l, mid, ql, qr);
    if (qr > mid) val += ask(rc[x], rc[y], mid + 1, r, ql, qr);
    return val;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        add(rt[i - 1], rt[i], 1, n, a[i]);
    }
    scanf("%d", &q);
    l[0] = 1, r[0] = n, minn[0] = 1, maxx[0] = n;  // 原始序列
    for (int i = 1; i <= q; i++) {
        scanf("%d%d%d", &ti, &si, &xi);
        l[i] = l[si], r[i] = r[si], minn[i] = minn[si], maxx[i] = maxx[si];
        if (ti == 1) {
            ll = l[i] - 1, rr = r[i];
            while (ll + 1 < rr) {
                midd = (ll + rr) >> 1;
                if (ask(rt[max(l[si] - 1, 0)], rt[midd], 0, n, minn[si], maxx[si]) <= xi) 
                    ll = midd;
                else 
                    rr = midd;
            }
            if (ask(rt[max(l[si] - 1, 0)], rt[rr], 0, n, minn[si], maxx[si]) <= xi)
                l[i] = rr + 1, r[si] = rr;
            else
                l[i] = ll + 1, r[si] = ll;
        } else {
            minn[i] = max(minn[i], xi + 1);
            maxx[si] = min(maxx[si], xi);
        }
        printf("%d\n", ask(rt[max(l[i] - 1, 0)], rt[r[i]], 0, n, minn[i], maxx[i]));
    }
}