#Solution for AT_abc176_f

· · 题解

AT_abc176_f 题解

这是一道非常有趣的题目,可以帮助深化对于 dp 的理解。

题意

给定 n 与一个长度为 3n 的数组。你需要每次从前五个中选出三个删除,若删除的三个元素相同则收益加一。求删空的最大收益。

解题

Part 1

首先,显然我们有一个朴素 dp,因为我们对于后面的操作只关心前面剩下的数,所以可以记录每次保留了哪两个数。转移形如:

F_{i, x, y}+[x=y=a] \to F_{i+1, b, c}\\ F_{i, x, y}+[x=a=b] \to F_{i+1, y, c}\\ F_{i,x,y}+[a=b=c] \to F_{i+1, x, y}

分别对应新增的三个数中,保留两个、保留一个、不保留。

直接转移容易做到 \mathcal{O}(n^{3})

Part 2

考虑我们每一层 F_{i,?,?} 的转移,仅依赖于上一层的的状态。因此考虑把 F 的后两位看作一个 n \times n 大小矩阵,转移视作映射 \mathcal{f},则为:

F_i = f(F_{i-1})

发现 \mathcal{f} 看上去不是一个矩阵乘法的形式。但是没关系,我们考察 \mathcal{f} 的映射关系。我们把 F_{i-1} 中的元素 (x,y)F_i 中它可达的元素 {a, b} 连一条边,观察一下这个图的长相。显然有 \mathcal{O}(n^2) 数量的边。首先尝试缩减状态数,但是由于答案要求的是最后一个矩阵的最大值,且状态转移间每个元素都有连边,因此每个元素都是有用的,无法缩减。故考虑优化转移。考察这个映射的边,首先观察到有大量映射是由元素映射到本身的,所以猜测有效的映射只有 \mathcal{O(n)} 个。接下来给出证明。

边权为 1 的边:

\begin{cases} 1+(x,y) \to (x,y)&&{a=b=c}\\ 1+(x,y) \to (y, c)&&{a=b=x}\\ 1+(x,y) \to (b,c)&&(x=y=a) \end{cases}

其中第一类,等价于将矩阵全局加一。可使用全局懒标记,将在稍后证明正确性。第二、三类,只有 \mathcal{O}(n) 条边,暴力转移即可。

边权为 0 的边:

\begin{cases} (x,y) \to (x,y)\\ (x,y) \to (y, c)\\ (x,y) \to (b,c) \end{cases}

其中第一类为恒等映射,不用处理。第二类,枚举 y,则相当于查询矩阵第 y 列的最大值。第三类,相当于查询整个矩阵的最大值。这些查询都可以在转移途中顺便记录

综上,我们可以在 \mathcal{O}(n) 的时间内,完成一次映射。做 n 次映射,故最终时间复杂度为 \mathcal{O}(n^2)

全局懒标记的正确性

我们的转移总是形如 (x,y) \to(a,b)。若在某一次操作中进行了矩形整体加一,则随后的转移:

1+(x,y) \to (a,b)

可以等价于两个映射:

(x,y) \to (a,b)\\ 1+(a,b) \to (a,b)

可验证二者满足交换律。故可将所有的第二类映射放在最后统一加上。

有些读者可能会注意到,我们会出现一些 "不合法" 的转移。比如删去了三个相同的数,却没有将收益加一。但是,由于我们的转移的确映射到了每个状态,并且我们要求的是取 max,故最终答案不变。

代码

#include <bits/stdc++.h>

using namespace std;

const int MAXN = 2e3 + 10;

template<class T>
void chkmax(T &a, const T &b) {
    if(a < b) a = b;
}

int n, n3;
int mxR[MAXN], mxv, f[MAXN][MAXN];
int h[3*MAXN];

struct U {
    int x, y, v;
};
vector<U> ud;

void add(int x, int y, int v) {
    ud.push_back(U{x, y, v});
}

int laz;

void apply() {
    for(auto &U : ud) chkmax(f[U.x][U.y], U.v), chkmax(mxR[U.x], f[U.x][U.y]), chkmax(mxv, f[U.x][U.y]);
}
template<class T>
bool eq(const T &a, const T&b, const T&c) {
    return a == b && b == c;
}

void tran(int a, int b, int c) {
    for(int y = 1; y <= n; y++) {
        add(y, c, mxR[y]);
        add(c, y, mxR[y]);
    }
    if(a == b) for(int y = 1; y <= n; y++) add(y, c, 1 + f[a][y]), add(c, y, 1 + f[a][y]);
    add(b, c, mxv), add(c, b, mxv);
    add(b, c, 1 + f[a][a]), add(c, b, 1 + f[a][a]);
}

void solve() {
    scanf("%d", &n), n3 = n * 3;
    for(int i = 1; i <= n3; i++) scanf("%d", h + i);
    memset(f, -0x3f, sizeof f);
    memset(mxR, -0x3f, sizeof mxR);
    f[h[1]][h[2]] = f[h[2]][h[1]] = 0;
    mxv = 0;
    mxR[h[1]] = mxR[h[2]] = 0;
    for(int i = 3; i < n3; i+=3) {
        int a = h[i], b = h[i+1], c = h[i+2];
        if(eq(a, b, c)) ++laz;
        else {
            for(int j = 0; j < 3; j++) for(int k = 0; k < 3; k++) for(int s = 0; s < 3; s++) {
                if(j != k && j != s && k != s) tran(h[i+j], h[i+k], h[i + s]);
            }
        }
        apply();
        ud.clear();
    }
    printf("%d", max(f[h[n3]][h[n3]]+1, mxv) + laz);
}

int main() {
    solve();
}

祝大家学习愉快!