题解:P8776 [蓝桥杯 2022 省 A] 最长不下降子序列

· · 题解

P8776 最长不下降子序列

本题的树状数组或线段树做法:

​ 首先回顾一下 LIS 问题,也就是最长(严格)上升子序列,除了朴素的 O(n^2) 算法外,还有维护一个 ends 数组,其中 ends_i 表示 i + 1 长度的 LIS 结尾数字是多少(需要数组下标从0开始的强迫症,如果默认下标从1开始做很小修改就行)。对于原数组中的 a_j ,可以从 ends 数组中找到第一个 \geq a_j 的位置 idx,那么以 a_j 为结尾的 LIS 长度就是 idx+1。易知 ends 数组一定严格单调增,那么可以通过二分查找确定 idx,从而将 LIS 问题优化为 O(nlogn)

​ 然后来看本题。如果目前已经处理到了 a_j,那么如何修改 a_j 之前的 k 个连续数字,让它目前的不下降子序列最长呢?我们可以枚举每个 a_j,然后将它之前的 k 个数字变为 a_j,再计算整体的最长不下降子序列。因为如果存在一个最长不下降子序列,那么它一定可以包含这 k 个经过改变的数字,并且经过改变的数字可以是这个最长不下降子序列元素中的任意一个。通过枚举 a_j,我们就可以枚举它所有改变的情况,并从中选出最大值。

​ 那么这个问题可以简化为:

  1. 对于 j 位置元素 a_j,以 a_j 结尾的在下标 [0,j-k-1] 上的最长不下降子序列长度 len_1
  2. 通过枚举 a_j,得到 len_1+len_2+len_3 即为该位置能得到的最长不下降子序列。

​ 其中 len_3 可以通过从后往前的最长递增子序列得出,len_1 也可以在从前往后枚举 j 的过程中得到。总的时间复杂度为 O(n\log n)

int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int n, k;
    cin >> n >> k;
    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    std::vector<int> r(n);
    [&]() -> void {
        std::vector<int> ends;
        for (int i = n - 1; i >= 0; i--) {
            int le = 0, ri = ends.size();
            while (le < ri) {
                int m = (le + ri) / 2;
                if (ends[m] >= a[i]) {
                    le = m + 1;
                } else {
                    ri = m;
                }
            }
            r[i] = le + 1;
            if (le == ends.size()) {
                ends.push_back(a[i]);
            } else {
                ends[le] = a[i];
            }
        }
    }();
    r.push_back(0);
    n++;
    a.push_back(INF);
    int ans = k;
    [&]() -> void {
        std::vector<int> ends;
        for (int i = k; i < n; i++) {
            int idx = std::upper_bound(ends.begin(), ends.end(), a[i - k]) - ends.begin();
            if (idx == ends.size()) {
                ends.push_back(a[i - k]);
            } else {
                ends[idx] = a[i - k];
            }
            idx = std::upper_bound(ends.begin(), ends.end(), a[i]) - ends.begin();
            ans = std::max(ans, i != n - 1 ? k + idx + r[i] : k + idx + r[i] - 1);
        }
    }();
    cout << ans << "\n";
    return 0;
}