[题解] P15038 「chaynOI R2 T3」合并同类项

· · 题解

二操作是一操作的特殊情况,因此先让 y \leftarrow \min(y, x)

一个暴力的做法是记 f_{l, r} 表示删掉区间 [l, r] 的最小代价,转移就是切成两段,用两段的代价和与删掉当前区间的代价来转移。统计答案就是枚举剩下的位置,两侧都要删掉。时间复杂度 O(n^3)

注意到剩下的包都有 k \ge 400 的限制,感觉应该是需要用一下随机的性质了。

观察到我们上面的转移其实是只关心和为 0 的区间的,和不为 0 的区间完全可以一个一个点地删掉。因此一个猜测就是和为 0 的区间数量不会特别多。

事实上确实是这样的,期望应该是 \frac{n(n+1)}{2k} 的,感性理解一下就是每个区间的和都是 [0, k) 的随机映射,所以应该只有 1/k 的区间和为 0

这样的话我们可以将状态更改为 f_{l, r} 表示将 [l, r] 这个和为 0 的区间删到只剩一个数的最小代价。从小到大枚举每个和为 0 的区间,转移就是从这个区间内部选若干个不交的和为 0 的区间,用删掉它们的代价加上其余位置删到只剩一个的代价转移。

具体来说,转移需要用 g_{i, 0 / 1} 表示考虑了区间中的前 i 个位置,有没有确定最后保留哪个位置的最小代价来辅助转移,g 的转移就是每次删掉当前位置或者删掉一个和为 0 的区间。这样的话转移 gO(\frac{n^2}{k}) 的,总时间复杂度 O(\frac{n^4}{k})

一种优化方法是对于左端点相同的和为 0 的区间,按照右端点从小到大的顺序来计算它们的 f 值,这样每次都能继承上一个区间的 g 数组,因此对于所有左端点为 l 的区间转移的复杂度就是 O(\frac{n^2}{k}),总时间复杂度也就变成了 O(\frac{n^3}{k})

#include <bits/stdc++.h>

using i64 = long long;

constexpr i64 inf = 1E18;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int n, k, x, y;
    std::cin >> n >> k >> x >> y;
    y = std::min(x, y);

    std::vector<int> a(n);
    for (int i = 0; i < n; i ++) {
        std::cin >> a[i];
    }

    std::vector<std::array<int, 2>> rs;
    std::vector<std::vector<int>> vec(n);
    for (int r = 1; r <= n; r ++) {
        int sum = 0;
        for (int l = r - 1; l >= 0; l --) {
            ((sum += a[l]) >= k) && (sum -= k);
            if (sum == 0 || (l == 0 && r == n)) {
                vec[l].push_back(rs.size());
                rs.push_back({l, r});
            }
        }
    }

    const int m = rs.size();
    std::vector<i64> f(m);
    std::vector<std::array<i64, 2>> g(n + 1);
    for (int i = 0, l, r = rs[0][1]; i < m; i ++) {
        if (i == 0 || rs[i][1] != r) {
            r = rs[i][1];
            l = r - 1;
            g[r][1] = inf;
            g[r][0] = 0;
        }
        while (l >= rs[i][0]) {
            g[l][1] = std::min(g[l + 1][0], g[l + 1][1] + x);
            g[l][0] = g[l + 1][0] + x;
            for (auto j : vec[l]) {
                if (rs[j][1] <= r && j != i) {
                    g[l][0] = std::min(g[l][0], f[j] + g[rs[j][1]][0] + y);
                    g[l][1] = std::min(g[l][1], f[j] + g[rs[j][1]][1] + y);
                    g[l][1] = std::min(g[l][1], f[j] + g[rs[j][1]][0]);
                }
            }
            -- l;
        }
        f[i] = g[rs[i][0]][1];
        l = rs[i][0];
    }

    std::cout << f.back() << '\n';

    return 0;
}