区间 assign max(checkmax) 区间和详解

· · 算法·理论

前言

之所以叫“详解”,是因为我打算阐述想出这个东西的全思维过程以及想法出现的“原因”。

这个做法是我自己口胡的,但是可能与经典做法重复 (一定重复)

问题简述

给定一个序列。你需要支持两种操作:

  1. 给定一个区间和一个数 x,将该区间内的所有小于 x 的元素赋为 x
  2. 求一个给定区间内所有元素的和

序列长度(以下设为 n)和操作个数(以下设为 q)不超过 10^5

以下设值域为 V

开始

众所周知,区间操作一般可以用以下的东西维护:

  1. 线段树
  2. 分块

(因为我们听说)这个操作用线段树难以维护,所以用分块。(具体见 "Why can't I use segment tree?" 一节)

首先,肯定先分块。

然后,考虑对于每个块维护。考虑打一个标记代表所有针对该块的操作1中最大的 x,再维护块内和。

然而我们发现,已知这些信息并不能求出当前的区间和。于是考虑维护更多的信息。

维护哪些信息

发现操作1是一个与大小关系相关的东西,也就是说与值域相关,这启发我们考虑维护与值域相关的信息。

如果你像我一样学平衡树学魔怔了,我们最先想到每个块内维护一个 Splay,每个节点上额外维护一个当前子树内的和。对于一整个块的查询,就只需要查出小于当前标记的数的个数与和,从总和中减去这部分和,然后加上个数与标记的乘积。

避免噩梦般的 Splay

通过一点简单的分析,我们发现权值树状数组也能达到同样的目的。于是我们就成功避免了写平衡树

但是,经过更进一步的思考,我们发现其实根本没有对块内的权值树状数组进行修改的必要!因为我们的修改操作只需要修改标记,而查询时只要有标记和最初序列的值域信息(具体来说,小于标记的数的个数和这些数的和)就足够了。

于是,直接换用前缀和。现在有了一个时间复杂度 O(n+(q+V)\sqrt{n}) 的做法。

但是这个做法的时空复杂度是和值域相关的,而且要求区间和,还不能离散化,如果 a_i\leq10^9 就做不了了。

脱离值域的限制

想想怎么脱离值域的限制。

想想权值线段树和平衡树的差别。平衡树内部仅有插入的元素节点,而权值线段树为了维护结构还有一大堆并不是插入元素的内部节点。

我们的前缀和不也如此吗?有很多根本就没有在块内出现过的值的前缀信息也被存下来了!

我们可以把这些东西都去掉。具体的,对于每个块,将所有块内元素从小到大排序,然后维护前缀和。查询时二分就可以知道小于标记的元素的个数和这些元素的和。

这样,我们就做完了。

细节

对散块的修改会打乱块内元素的大小顺序。所以对散块的修改必须暴力下传标记。

复杂度分析

查询是散块 O(\rm{len}),整块 O(\rm{len}\times\log\rm{len}) 的,修改的散块和整块复杂度则刚好相反。所以只能取块长 len=O(\sqrt{n}),单次操作时间复杂度 O(\sqrt n\log n)

Why can't I use segment tree?

其实也可以用 :P。同样地,对每个节点维护标记、小于当前的标记的值的个数与和。

预处理时间复杂度 O(n\log^2 n) \text{ 或 } O(n\log n) (如果合并子节点信息时使用双指针而不是重新排序可以达到后面那个复杂度),查询时间复杂度 O(q\log^2 n),空间复杂度 O(n\log n)

提交入口

https://www.luogu.com.cn/problem/U530370

相似题

教主的魔法

BZOJ4695 最佳女选手\ 以及我的题解:https://www.luogu.com.cn/article/tn0wdjz7

代码

慢慢更。

已经写完了,但是还没有验正确性。

#include <bits/stdc++.h>

using namespace std;
int n, q, len, num;
array<int, 100005> a;
array<int, 1005> sum, max_tag;
array<vector<int>, 1005> block;
array<vector<long long>, 1005> prefix_sum;

inline int belong(int pos) { return (pos - 1) / len + 1; }

long long query(int l, int r) {
    int sid = belong(l), eid = belong(r);
    long long res = 0;

    if (sid == eid) {
        for (int i = l; i <= r; ++i) {
            if (a[i] > max_tag[sid]) {
                res += a[i];
            } else {
                res += max_tag[sid];
            }
        }
    } else {
        for (int i = l; belong(i) == sid; ++i) {
            if (a[i] > max_tag[sid]) {
                res += a[i];
            } else {
                res += max_tag[sid];
            }
        }

        for (int i = r; belong(i) == eid; --i) {
            if (a[i] > max_tag[eid]) {
                res += a[i];
            } else {
                res += max_tag[eid];
            }
        }

        for (int id = sid + 1; id < eid; ++id) {
            int cnt = lower_bound(block[id].begin(), block[id].end(), max_tag[id]) - block[id].begin();

            if (cnt > 0) {
                long long x = sum[cnt];

                res += sum[id] - x + 1ll * cnt * max_tag[id];
            } else {
                res += sum[id];
            }
        }
    }

    return res;
}

void check_max(int l, int r, int x) {
    int sid = belong(l), eid = belong(r);

    if (sid == eid) {
        if (x > max_tag[sid]) {
            for (int i = l; i <= r; ++i) {
                a[i] = max(a[i], x);
            }

            for (int i = (sid - 1) * len + 1; i <= min(sid * len, n); ++i) {
                a[i] = max(a[i], max_tag[sid]);
            }

            block[sid].clear();
            prefix_sum[sid].clear();

            for (int i = (sid - 1) * len + 1; i <= min(sid * len, n); ++i) {
                block[sid].push_back(a[i]);
            }

            sort(block[sid].begin(), block[sid].end());

            prefix_sum[sid].push_back(0);
            for (int i = 0; i < block[sid].size(); ++i) {
                prefix_sum[sid].push_back(prefix_sum[sid].back() + block[sid][i]);
            }
        }
    } else {
        if (x > max_tag[sid]) {
            for (int i = l; belong(i) == sid; ++i) {
                a[i] = max(a[i], x);
            }

            for (int i = l - 1; belong(i) == sid; --i) {
                a[i] = max(a[i], max_tag[sid]);
            }

            block[sid].clear();
            prefix_sum[sid].clear();

            for (int i = (sid - 1) * len + 1; i <= min(sid * len, n); ++i) {
                block[sid].push_back(a[i]);
            }

            sort(block[sid].begin(), block[sid].end());

            prefix_sum[sid].push_back(0);
            for (int i = 0; i < block[sid].size(); ++i) {
                prefix_sum[sid].push_back(prefix_sum[sid].back() + block[sid][i]);
            }
        }

        if (x > max_tag[eid]) {
            for (int i = r; belong(i) == eid; --i) {
                a[i] = max(a[i], x);
            }

            for (int i = r + 1; belong(i) == eid; ++i) {
                a[i] = max(a[i], max_tag[eid]);
            }

            block[eid].clear();
            prefix_sum[eid].clear();

            for (int i = (eid - 1) * len + 1; i <= min(eid * len, n); ++i) {
                block[eid].push_back(a[i]);
            }

            sort(block[eid].begin(), block[eid].end());

            prefix_sum[eid].push_back(0);
            for (int i = 0; i < block[eid].size(); ++i) {
                prefix_sum[eid].push_back(prefix_sum[eid].back() + block[eid][i]);
            }
        }

        for (int id = sid + 1; id < eid; ++id) {
            max_tag[id] = max(max_tag[id], x);
        }
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> q;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }

    len = floor(sqrt(n));
    num = (n + len - 1) / len;

    for (int i = 1; i <= n; ++i) {
        sum[belong(i)] += a[i];
        block[belong(i)].push_back(a[i]);
    }

    for (int i = 1; i <= num; ++i) {
        max_tag[i] = -1e9;
        sort(block[i].begin(), block[i].end());

        prefix_sum[i].push_back(0);
        for (int j = 0; j < block[i].size(); ++j) {
            prefix_sum[i].push_back(block[i][j] + prefix_sum[i].back());
        }
    }

    for (int _ = 0; _ < q; ++_) {
        int op, l, r, x;
        cin >> op >> l >> r;

        if (op == 1) {
            cin >> x;
            check_max(l, r, x);
        } else {
            cout << query(l, r) << "\n";
        }
    }
    return 0;
}