题解:P14558 [ROI 2013 Day2] 大规模预测

· · 题解

蠢蠢根号分治做法。

考虑对 v_i 的出现次数根号分治一下,对于 \le B 的部分,枚举一下左端点 l,在 [l, l + 2B - 1] 中枚举众数 v_i,设 [l, i]v_i 的出现次数为 t,那么 [l, \text{nxt}_{i} - 1]v_i 的出现次数也是 t。推导一下我们能取到多少个右端点:

2t > r - l + 1 \Rightarrow r \le l + 2t - 2.

故可以取 \min(\text{nxt}_i - 1, l + 2t - 2) - i + 1 个右端点。

对于 >B 的部分,我们先枚举这个值 x = v_i,然后转换一下问题:如果 x = v_j 那么 b_j = 1 否则 b_j = -1,求使得 b 的区间和 > 0 的区间数量。

这就等价于对于每个 j,找一下 <j 的前缀里小于当前前缀 p_j 的数量。由于相邻前缀和的差为 \pm 1,所以考虑开个桶维护,记作 c。思考从 p_{j - 1}p_{j} 会产生什么贡献:

单次可以 O(n) 求解。

这个题就做完了,时间复杂度 O(nB + \frac{n^2}{B})B150 再卡点小常能过。

/*
 * author: LostKeyToReach
 * created time: 2025-12-01 18:00:19
 */
#include <bits/stdc++.h>
// #define int long long
#define vt std::vector
#define vi vt<int>
#define eb emplace_back
using ll = long long;
using pii = std::pair<int, int>;
#define all(x) (x).begin(),(x).end()
#define sz(x) ((int)(x).size())
#define S std::cin.tie(0)->sync_with_stdio(0)
#define chkmax(x, y) x = std::max(x, y)
#define chkmin(x, y) x = std::min(x, y)
int fio = (S, 0);
constexpr int N = 1e6 + 5, base = 5e5, B = 150;
int n, k, v[N], cnt[N], h[N], nxt[N], lst[N], clr[N], tot;
int32_t main() {
    std::cin >> n >> k;
    for (int i = 1; i <= n; ++i)
        std::cin >> v[i], ++cnt[v[i]];
    for (int i = 1; i <= n; ++i) if (cnt[v[i]] > B)
        h[v[i]] = 1;
    for (int i = 1; i <= n; ++i) cnt[v[i]] = 0, lst[v[i]] = n + 1;
    for (int i = n; i >= 1; --i)
        nxt[i] = lst[v[i]], lst[v[i]] = i;
    long long ans = 0;
    for (int l = 1; l <= n; ++l) {
        for (int i = l, t; i <= std::min(n, l + 2 * B - 1); ++i) {
            if (h[v[i]]) continue;
            if ((t = ++cnt[v[i]]) == 1) clr[++tot] = v[i];
            ans += std::max(0ll, (ll)std::min(l + t * 2 - 2, nxt[i] - 1) - i + 1);
        }
        for (int i = 1; i <= tot; ++i) cnt[clr[i]] = 0;
        tot = 0;
    }
    for (int i = 1; i <= k; ++i) if (h[i]) {
        clr[tot = 1] = base;
        long long cur = 0; cnt[base] = 1;
        int sum = 0;
        for (int j = 1; j <= n; ++j) {
            if (v[j] == i) {
                cur += cnt[sum + base];
                sum += 1;
            } else {
                sum -= 1;
                cur -= cnt[sum + base];
            }
            ans += cur;
            if (!cnt[sum + base]++) clr[++tot] = sum + base;
        }   
        // std::cout << ans << "\n";
        for (int j = 1; j <= tot; ++j)
            cnt[clr[j]] = 0;
    }
    std::cout << ans << "\n";
}