Eternal (Hard ver.) 题解

· · 题解

书接上回,先总结一下简单版的做法,在一个块内,我们需要枚举断点,然后用一个前缀和一个后缀来更新状态。

然而很容易发现,最优状态一定在区间的端点上取得。故我们记以 i 为左端点、右端点的区间集合分别为 S_i,T_i,则处理 [l,r] 的状态时,只需枚举 S_l,T_r 中较小的一个集合,然后查询另一侧贡献的最大值即可。显然这个贡献是具有单调性的,因此可以二分。当然你也可以写一棵线段树,不过可能会被卡常。

根据三元环计数的结论,\sum \min(|S_i|,|T_i|)=O(|E| \sqrt{|E|}),故复杂度 O(n \sqrt{n} \log n)

你可能需要注意一些细节,包括但不限于重复区间、单点区间,可参考 std。

:::success[实现]{open}

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e5 + 10;
int n, maxc, ecnt, rcnt, pnt[MAXN], dp[MAXN];
vector <int> lidx[MAXN], ridx[MAXN], stra[MAXN], strb[MAXN];
struct Edge{
    int l, r, c, val, pre, suf;
}e[MAXN];
struct RawEdge{
    int l, r;
    bool operator < (const RawEdge& o) const{
        if (l != o.l){
            return l < o.l;
        }
        return r < o.r;
    }
    bool operator == (const RawEdge& o) const{
        return l == o.l && r == o.r;
    }
}re[MAXN];
int askl(int l, int rlim){
    int low = 0, high = (int)lidx[l].size() - 1, res = 0;
    while (low <= high){
        int mid = (low + high) >> 1;
        int id = lidx[l][mid];
        if (e[id].r <= rlim){
            res = e[id].pre;
            low = mid + 1;
        }
        else{
            high = mid - 1;
        }
    }
    return res;
}
int askr(int r, int llim){
    int low = 0, high = (int)ridx[r].size() - 1, res = 0;
    while (low <= high){
        int mid = (low + high) >> 1;
        int id = ridx[r][mid];
        if (e[id].l >= llim){
            res = e[id].suf;
            low = mid + 1;
        }
        else{
            high = mid - 1;
        }
    }
    return res;
}
void solve(){
    cin >> n;
    maxc = 0;
    rcnt = 0;
    for (int i = 1; i <= n; i++){
        int u, v;
        cin >> u >> v;
        maxc = max(maxc, v);
        if (u == v){
            pnt[u]++;
        }
        else {
            rcnt++;
            re[rcnt].l = u;
            re[rcnt].r = v;
        }
    }
    sort(re + 1, re + 1 + rcnt);
    ecnt = 0;
    for (int i = 1; i <= rcnt; ){
        int j = i;
        while (j <= rcnt && re[j] == re[i]){
            j++;
        }
        ecnt++;
        e[ecnt].l = re[i].l;
        e[ecnt].r = re[i].r;
        e[ecnt].c = j - i;
        e[ecnt].val = e[ecnt].pre = e[ecnt].suf = 0;
        i = j;
    }
    for (int i = 1; i <= maxc; i++){
        lidx[i].clear();
        ridx[i].clear();
        stra[i].clear();
        strb[i].clear();
    }
    for (int i = 1; i <= ecnt; i++){
        lidx[e[i].l].push_back(i);
        ridx[e[i].r].push_back(i);
    }
    for (int i = 1; i <= maxc; i++){
        if (!lidx[i].empty()){
            sort(lidx[i].begin(), lidx[i].end(), [](int a, int b){
                return e[a].r < e[b].r;
            });
            int sum = 0;
            for (int id : lidx[i]){
                sum += e[id].c;
                e[id].pre = sum;
            }
        }
        if (!ridx[i].empty()){
            sort(ridx[i].begin(), ridx[i].end(), [](int a, int b){
                return e[a].l > e[b].l;
            });
            int sum = 0;
            for (int id : ridx[i]){
                sum += e[id].c;
                e[id].suf = sum;
            }
        }
    }
    for (int i = 1; i <= ecnt; i++){
        int u = e[i].l, v = e[i].r;
        if (lidx[u].size() > ridx[v].size() || (lidx[u].size() == ridx[v].size() && u > v)){
            stra[u].push_back(i);
        }
        else{
            strb[v].push_back(i);
        }
    }
    for (int l = 1; l <= maxc; l++){
        if (stra[l].empty()){
            continue;
        }
        for (int id : stra[l]){
            int r = e[id].r;
            int vl = askl(l, r - 1), vr = askr(r, l + 1), maxk = 0;
            if (l + 1 <= r - 1){
                maxk = vl + askr(r, r - 1);
                for (int id2 : ridx[r]){
                    int k = e[id2].l;
                    if (l + 1 <= k && k <= r - 1){
                        maxk = max(maxk, askl(l, k) + e[id2].suf);
                    }
                }
            }
            e[id].val = e[id].c + max({vl, vr, maxk});
        }
    }
    for (int r = 1; r <= maxc; r++){
        if (strb[r].empty()){
            continue;
        }
        for (int id : strb[r]){
            int l = e[id].l;
            int vl = askl(l, r - 1), vr = askr(r, l + 1), maxk = 0;
            if (l + 1 <= r - 1){
                maxk = askl(l, l + 1) + vr;
                for (int id2 : lidx[l]){
                    int k = e[id2].r;
                    if (l + 1 <= k && k <= r - 1){
                        maxk = max(maxk, e[id2].pre + askr(r, k));
                    }
                }
            }
            e[id].val = e[id].c + max({vl, vr, maxk});
        }
    }
    dp[0] = 0;
    for (int i = 1; i <= maxc; i++){
        dp[i] = dp[i - 1] + pnt[i];
        for (int id : ridx[i]){
            int l = e[id].l;
            dp[i] = max(dp[i], dp[l] + e[id].val + pnt[i]);
        }
    }
    cout << dp[maxc] << "\n";
    for (int i = 1; i <= maxc; i++){
        pnt[i] = 0;
        dp[i] = 0;
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while (t--){
        solve();
    }
    return 0;
}

:::