题解:P12485 [集训队互测 2024] PM 大师

· · 题解

写这个题跟用户名没啥关系。

a_i=0 的位置为 c_1,\cdots,c_mpre_ia_{1\sim i}0 的个数。可以发现 b_{c_i} 是严格单调递增的,设其中的数值构成的集合为 S

从值域考虑,不难发现对于非 0a_i=v,我们只关心其第一次出现的位置 p_v,设 d_v=pre_{p_v}(若没有出现 vd_v=m)。按值域从小到大枚举 v=1\sim n,设 s_v1\sim v-1\notin S 的数的个数,则

v\notin S\Leftrightarrow v-d_v-s_v>0

查询 a_y 时,若 a_y=0,则相当于查询第 pre_y 小的 S 中的数。在值域上开一棵线段树,区间内维护 v-d_v-s_v\leq 0 的数的个数,线段树二分即可得到结果。

考虑如何处理修改。每次将 a_x 修改为 a_x' 后,d_{a_x}d_{a_x'} 都可能变化。这相当于做两次对 v-d_v-s_v 的单点修改。考察某次单点修改带来的变化,若原先 v-d_v-s_v>0,修改后 v-d_v-s_v\leq 0,则 v 插入 S 中,v+1\sim n 中的值的 s_v 都会减 1。注意到,这可能导致 v+1\sim n 中的某些值的 v-d_v-s_v\leq 0 变成 >0,此时我们取最小的这样的数 v_{\min},则 v_{\min} 将从 S 中删除,删除后更大的值就恢复正常了。另一种情况也是对称的。

同样用线段树维护,区间内维护 \in S 的值的最大值及其位置,和 \notin S 的值的最小值及其位置,还有加法标记即可。p_v 可以用 set 或者懒删除的堆维护。时间复杂度为 \mathcal{O}((n+q)\log{n})

:::success[主要代码]

int n, q, a[MAXN], pre[MAXN], val[MAXN];
bool vis[MAXN];
priority_queue<int, vector<int>, greater<>> Q[MAXN];

struct SegTree {
#define ls(p) (p << 1)
#define rs(p) (p << 1 | 1)
    struct Node {
        int mn, mnp, mx, mxp, cnt;
        Node &operator+=(const Node &rhs) {
            if (rhs.mn < mn) {
                mn = rhs.mn;
                mnp = rhs.mnp;
            }
            if (rhs.mx > mx) {
                mx = rhs.mx;
                mxp = rhs.mxp;
            }
            cnt += rhs.cnt;
            return *this;
        }
        friend Node operator+(Node lhs, const Node &rhs) {
            return lhs += rhs;
        }
    } nd[MAXN << 2];
    int tg[MAXN << 2];

    void build(int p, int l, int r) {
        if (l == r) {
            nd[p] = {inf, 0, -inf, 0, 0};
            if (val[l] > 0) {
                vis[l] = false;
                nd[p].mn = val[l];
                nd[p].mnp = l;
            } else {
                vis[l] = true;
                nd[p].mx = val[l];
                nd[p].mxp = l;
                nd[p].cnt = 1;
            }
            return;
        }
        int mid = l + r >> 1;
        build(ls(p), l, mid);
        build(rs(p), mid + 1, r);
        pushUp(p);
    }
    void pushUp(int p) {
        nd[p] = nd[ls(p)] + nd[rs(p)];
    }
    void applyAdd(int p, int v) {
        tg[p] += v;
        if (nd[p].mn != inf) nd[p].mn += v;
        if (nd[p].mx != -inf) nd[p].mx += v;
    }
    void pushDown(int p) {
        if (!tg[p]) return;
        applyAdd(ls(p), tg[p]);
        applyAdd(rs(p), tg[p]);
        tg[p] = 0;
    }
    int query(int p, int l, int r, int x) {
        if (l == r) return vis[l] ? nd[p].mx : nd[p].mn;
        pushDown(p);
        int mid = l + r >> 1;
        return x <= mid ? query(ls(p), l, mid, x) : query(rs(p), mid + 1, r, x);
    }
    void upd(int p, int l, int r, int x) {
        if (l == r) {
            if (nd[p].mn <= 0) {
                nd[p].mx = nd[p].mn;
                nd[p].mn = inf;
                swap(nd[p].mnp, nd[p].mxp);
            } else if (nd[p].mx > 0) {
                nd[p].mn = nd[p].mx;
                nd[p].mx = -inf;
                swap(nd[p].mnp, nd[p].mxp);
            }
            nd[p].cnt = vis[l];
            return;
        }
        pushDown(p);
        int mid = l + r >> 1;
        if (x <= mid) upd(ls(p), l, mid, x);
        else upd(rs(p), mid + 1, r, x);
        pushUp(p);
    }
    void add(int p, int l, int r, int x, int y, int v) {
        if (x <= l && y >= r) {
            applyAdd(p, v);
            return;
        }
        pushDown(p);
        int mid = l + r >> 1;
        if (x <= mid) add(ls(p), l, mid, x, y, v);
        if (y > mid) add(rs(p), mid + 1, r, x, y, v);
        pushUp(p);
    }
    int find(int p, int l, int r, int k) {
        if (l == r) return l;
        pushDown(p);
        int mid = l + r >> 1;
        return k <= nd[ls(p)].cnt ? find(ls(p), l, mid, k) : find(rs(p), mid + 1, r, k - nd[ls(p)].cnt);
    }
#undef ls
#undef rs
} sgt;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> q;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
        pre[i] = pre[i - 1] + (a[i] == 0);
    }
    int m = pre[n];
    for (int v = 1, cnt = 0; v <= n; ++v) {
        val[v] = v - m - cnt;
        if (val[v] > 0) ++cnt;
    }
    sgt.build(1, 1, n);
    while (q--) {
        int x, k, y;
        cin >> x >> k >> y;

        auto getD = [&](int v) {
            if (v <= 0) return -1;
            while (!Q[v].empty() && a[Q[v].top()] != v) Q[v].pop();
            return Q[v].empty() ? m : pre[Q[v].top()];
        };
        auto upd = [&](int v, int dt) {
            int prvVal = sgt.query(1, 1, n, v), nxtVal = prvVal + dt;
            sgt.add(1, 1, n, v, v, dt);
            if (prvVal > 0 && nxtVal <= 0) {
                vis[v] = true;
                sgt.upd(1, 1, n, v);
                if (v < n) {
                    sgt.add(1, 1, n, v + 1, n, 1);
                    if (sgt.nd[1].mx > 0) {
                        int p = sgt.nd[1].mxp;
                        vis[p] = false;
                        sgt.upd(1, 1, n, p);
                        if (p < n) sgt.add(1, 1, n, p + 1, n, -1);
                    }
                }
            } else if (prvVal <= 0 && nxtVal > 0) {
                vis[v] = false;
                sgt.upd(1, 1, n, v);
                if (v < n) {
                    sgt.add(1, 1, n, v + 1, n, -1);
                    if (sgt.nd[1].mn <= 0) {
                        int p = sgt.nd[1].mnp;
                        vis[p] = true;
                        sgt.upd(1, 1, n, p);
                        if (p < n) sgt.add(1, 1, n, p + 1, n, 1);
                    }
                }
            }
        };

        int tmp = a[x], prv1 = getD(tmp), prv2 = getD(k);
        a[x] = k;
        if (k > 0) Q[k].emplace(x);
        int nxt1 = getD(tmp), nxt2 = getD(k);
        if (tmp > 0) upd(tmp, prv1 - nxt1);
        if (k > 0) upd(k, prv2 - nxt2);
        cout << (a[y] ? a[y] : sgt.find(1, 1, n, pre[y])) << '\n';
    }
    return 0;
}

:::