题解:P3886 [JLOI2009] 神秘的生物

· · 题解

本文节选自插头 DP 学习笔记。

2.4 P3886 [JLOI2009] 神秘的生物

题意转化为:将方阵中的若干个点染黑,使得黑点的权值和最大,并且黑点形成一个连通块。

考虑插头 DP。注意本题中分界线设置为方格而不是方格之间的直线,因此分界线的长度即为 n 的值。并在 DP 的维度中记录分界线上所有黑点所处的连通块编号。

因为列数最多为 9,所以分界线上最多有 5 个连通块,再加上一个白点的状态,需要用 6 进制数表示。为了卡常,可以使用 8 进制数存储。进一步优化,可以使用连通块的最小表示法,即从左到右依次将连通块赋予编号。

转移的时候需要注意:新加入的一个点如果是白点,加入后黑点的连通块个数不能减少,如果减少了就意味着加入了一个孤立的连通分量,这不符合题目的要求。

时间复杂度 O(n^3|S|)。其中 |S| 表示本质不同的连通块最小表示法个数。

#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 10, M = 1e6, B = 999983, inf = 0x3f3f3f3f;
int n, a[N][N], bit[N], bas[N], ans = -inf;

struct HashTable {
    int h[M], id[M], val[M], idx, ne[M];
    void clear() {
        memset(h, 0, sizeof(h));
        idx = 0;
    }
    void insert(int st, int v) {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] = max(val[i], v);
                return;
            }
        }

        ne[++idx] = h[st % B];
        h[st % B] = idx;
        id[idx] = st;
        val[idx] = v;
    }
} dp[2];

struct DSU {
    int fa[N];
    void init() {
        for(int i = 0; i < N; i++) fa[i] = i;
    }
    int findf(int x) {
        if(fa[x] != x) fa[x] = findf(fa[x]);
        return fa[x];
    }
    void combine(int x, int y) {
        int fx = findf(x), fy = findf(y);
        if(fx < fy) swap(fx, fy);
        fa[fx] = fy;
    }
} dsu;

int Trans(int st, int pos) {
    int res = 0; dsu.init();
    int val[N], nid[N], pre[N], cnt = 0;
    memset(pre, -1, sizeof(pre));
    val[0] = 0;
    pre[0] = 0;
    nid[0] = 0;

    for(int i = 1; i <= n; i++) {
        val[i] = (st >> bit[i]) & 7;
        nid[i] = -1;
        if(pre[val[i]] >= 0 && (pos != i || val[i]))
            dsu.combine(i, pre[val[i]]);
        if(pos != i || val[i])
            pre[val[i]] = i;
        if(i + 1 == pos && val[i]) dsu.combine(i, i + 1);
        if(i > 1 && val[i] && val[i - 1]) dsu.combine(i, i - 1);
    }

    for(int i = 1; i <= n; i++) {
        int f = dsu.findf(i);
        if(nid[f] < 0) nid[f] = ++cnt;
        res += bas[i] * nid[f];
    }

    return res;
}

bool check(int st) {
    for(int i = 1; i <= n; i++)
        if(((st / bas[i]) & 7) > 1)
            return 0;
    return 1;
}

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = i * 3;
        bas[i] = (1 << bit[i]);
    }

    // Input
    cin >> n;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            cin >> a[i][j];

    // DP
    int now = 0, pre = 1;

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= n; j++) {
            swap(now, pre);
            dp[now].clear();
            dp[now].insert(bas[j], a[i][j]);

            ans = max(ans, a[i][j]);
            for(int k = 1; k <= dp[pre].idx; k++) {
                // Combine
                int vst = Trans(dp[pre].id[k], j);
                dp[now].insert(vst, dp[pre].val[k] + a[i][j]);
                if(check(vst)) ans = max(ans, dp[pre].val[k] + a[i][j]);

                // Not Combine
                vst = dp[pre].id[k];
                int tmp = ((vst / bas[j]) & 7);
                bool flag = (tmp == 0); 
                for(int p = 1; p <= n; p++) {
                    if(p == j) continue;
                    if(((vst / bas[p]) & 7) == tmp) flag = 1;
                }
                if(flag) dp[now].insert(Trans(vst - bas[j] * ((vst / bas[j]) & 7), -1), dp[pre].val[k]);
            }
        }
    }

    cout << ans;
    return 0;
}