[ABC176F] Brave CHAIN(动态规划)

· · 题解

Atcoder F - Brave CHAIN

分析

因为我们每一次是必须要进行移除的,所以如果能得到尽量大的价值,我们就要争取

首先$dp$全部置为最小值 $dp[0][A[1]][A[2]] = dp[0][A[2]][A[1]] = 0 i$枚举从$0$到$n - 1

现在加入了新的a = A[3i], b = A[3i + 1], c = A[3i + 2]三张牌

  1. a == b == c dp[i + 1][x][y] = dp[i][x][y] + 1$,对于所有的$x$,$y

    我们可以尝试把第一维省去,另外开一个数字来记录这种情况,sum ++, 时间复杂度是O(1)

我们随机选则出前面选则的另外一个数字x

(1) a == b

dp[i + 1][c][x] = max(dp[i][a][x] + 1, dp[i][x][a] + 1)

(2) b == c

dp[i + 1][a][x] = max(dp[i][b][x] + 1, dp[i][x][b] + 1)

(3) a == c

dp[i + 1][b][x] = max(dp[i][a][x] + 1, dp[i][x][a] + 1)

注意:转移的时候为了避免省去第一维造成的影响,我开了另外一个数组f进行更新,然后再回去更新dp

这种转移的时间复杂度是O(n)的,因为我们要枚举x

3.a,b, c中所有的数字都是互相不相同的

(1)我们在前面选择两个相同的数字x来消除现在的一个数字

消除a dp[i + 1][b][c] = max(dp[i][a][a] + 1)

消除b dp[i + 1][a][c] = max(dp[i][b][b] + 1)

消除c dp[i + 1][a][b] = max(dp[i][c][c] + 1)

(2)我们在前面选则两个数字x, y,最后留下三个数字中的两个

dp[i + 1][a][b] = dp[i + 1][a][c] = dp[i + 1][b][c] = max(dp[i][x][y]) 

这个地方需要我们使用一个数组pre[x]来表示前面剩下数字x和另外某个数字的dp的最大值就可以O(1)转移 时间复杂度O(n)

所以我们还需要记录前面dp的最大值mx,直接O(1)更新就可以了

(3)我们在前面选则两个数字x,y,最后留下三个数字中的随便一个

留下a dp[i + 1][a][x] = max(dp[i][x][y], dp[i][y][x]) -> mx

留下b dp[i + 1][b][x] = max(dp[i][x][y], dp[i][y][x]) -> mx

留下c dp[i + 1][c][x] = max(dp[i][x][y], dp[i][y][x]) -> mx

(4)我们在前面选则两个数字xy,删除现在所有的三个数字

dp[i + 1][x][y] = max(dp[i][x][y])   

因为省去了第一维,所以不用更新;

计算的过程中会剩下最后一个数字,暴力枚举前面的情况,如果能消除就ans + 1,然后取最大的ans就可以了,最后答案是ans + sum

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 3005, inf = 0x3f3f3f3f;
int n, dp[N][N], A[N * 3], mx = 0, sum = 0, ans = 0, f[N][N], pre[N];
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n * 3; ++ i) scanf("%d", &A[i]);
    for (int i = 0; i <= n; ++ i)
    {
        for (int j = 0; j <= n; ++ j) dp[i][j] = -inf, f[i][j] = -inf, pre[i] = -inf;
    }
    dp[A[1]][A[2]] = dp[A[2]][A[1]] = 0;
    pre[A[1]] = 0, pre[A[2]] = 0;
    mx = 0;
    for (int i = 1; i <= n - 1; ++ i)
    {
        int a = A[i * 3], b = A[i * 3 + 1], c = A[i * 3 + 2];
        if(a == b && b == c) sum ++;
        else 
        {
            if(a == b) 
            {
                for (int x = 1; x <= n; ++ x) f[c][x] = max(f[c][x], max(dp[x][a] + 1, dp[a][x] + 1));
            }
            if(a == c)
            {
                for (int x = 1; x <= n; ++ x) f[b][x] = max(f[b][x], max(dp[x][a] + 1, dp[a][x] + 1));
            } 
            if(b == c)
            {
                for (int x = 1; x <= n; ++ x) f[a][x] = max(f[a][x], max(dp[x][b] + 1, dp[b][x] + 1));
            }
            for (int x = 1; x <= n; ++ x)
            {
                f[a][x] = max(f[a][x], pre[x]);
                f[b][x] = max(f[b][x], pre[x]);
                f[c][x] = max(f[c][x], pre[x]);
            }
            f[a][b] = max(f[a][b], mx);
            f[b][c] = max(f[b][c], mx);
            f[a][c] = max(f[a][c], mx);

            f[b][c] = max(f[b][c], dp[a][a] + 1);
            f[a][c] = max(f[a][c], dp[b][b] + 1);
            f[a][b] = max(f[a][b], dp[c][c] + 1); 
            for (int x = 1; x <= n; ++ x)
            {
                dp[a][x] = f[a][x], dp[b][x] = f[b][x], dp[c][x] = f[c][x]; 
                pre[x] = max(pre[x], max(f[a][x], max(f[b][x], f[c][x])));
                pre[a] = max(pre[a], f[a][x]);
                pre[b] = max(pre[b], f[b][x]);
                pre[c] = max(pre[c], f[c][x]);
                mx = max(mx, f[a][x]), mx = max(mx, f[b][x]), mx = max(mx, f[c][x]);
            }
            for (int x = 1; x <= n; ++ x) f[a][x] = -inf, f[b][x] = -inf, f[c][x] = -inf;   
        }   
    }
    for (int x = 1; x <= n; ++ x)
    {
        for (int y = 1; y <= n; ++ y)
        {
            if(x == y && y == A[n * 3]) ans = max(ans, max(dp[x][y], dp[y][x]) + 1);
            else ans = max(ans, dp[x][y]);
        }
    }
    printf("%d", ans + sum);
    return 0;
}