题解:P3272 [SCOI2011] 地板

· · 题解

本文节选自插头 DP 学习笔记。

2.5 P3272 [SCOI2011] 地板

数据范围很小,且 L 形的性质很好,考虑插头 DP。先确定分界线的形态,显然直接套用一般的插头 DP 分界线即可。第二步考虑如何在分界线上记录信息,对于 L 形的地板,大致可以分为三类插头:空插头、后面必须收尾的插头、后面必须有拐点的插头,这三类插头分别记为 0, 1, 2。因此信息转化为一个三进制数,为了优化常数,转化为四进制数加上哈希表存储 DP 状态。

转移的时候对分界线拐点处的右插头、下插头状态进行分类讨论。

时间复杂度 O(nm2^{\min\{n, m\}})。注意一个细节,一开始如果 n < m 要将矩阵旋转一下再做插头 DP。

#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int mod = 20110520;
const int N = 105, M = 1e6, B = 999983;
int n, m, a[N][N];

void add(int &x, int val) {
    x += val;
    if(x >= mod) x -= mod;
}

struct HashTable{
    int h[M], id[M], val[M], ne[M], idx;

    void clear() {
        memset(h, 0, sizeof(h));
        idx = 0;
    }

    void insert(int st, int v) {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                add(val[i], v);
                return;
            }
        }

        ne[++idx] = h[st % B];
        h[st % B] = idx;
        id[idx] = st;
        val[idx] = v;
    }
} dp[2];

void Rotate() {
    int b[N][N];
    memset(b, 0, sizeof(b));
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            b[j][n - i + 1] = a[i][j];
        }
    }
    swap(n, m);
    memcpy(a, b, sizeof(b));
}

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c; cin >> c;
            a[i][j] = (c == '_');
        }
    }
    if(m > n) Rotate();

    // Init
    for(int i = 0; i <= 11; i++) {
        bit[i] = (i << 1);
        bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k], val = dp[pre].val[k];
                int lft = (st >> bit[j - 1]) & 3;
                int up = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!lft && !up)
                        dp[now].insert(st, val);
                }
                else if(!lft && !up) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | bas[j - 1] | bas[j], val);
                    if(a[i + 1][j])
                        dp[now].insert(st | (2 * bas[j - 1]), val);
                    if(a[i][j + 1])
                        dp[now].insert(st | (2 * bas[j]), val);
                }
                else if(lft && up) {
                    if(lft == up && lft == 2)
                        dp[now].insert(st ^ (2 * bas[j - 1]) ^ (2 * bas[j]), val);
                }
                else if(!up) {
                    if(a[i][j + 1])
                        dp[now].insert(st ^ (lft * bas[j - 1]) ^ (up * bas[j]) ^ (lft * bas[j]), val);

                    if(lft == 1) {
                        dp[now].insert(st ^ (lft * bas[j - 1]), val);
                    }
                    else if(lft == 2) {
                        if(a[i + 1][j])
                            dp[now].insert(st - bas[j - 1], val);
                    }
                }
                else if(!lft) {
                    if(a[i + 1][j])
                        dp[now].insert(st ^ (lft * bas[j - 1]) ^ (up * bas[j]) ^ (up * bas[j - 1]), val);

                    if(up == 1) {
                        dp[now].insert(st ^ (up * bas[j]), val);
                    }
                    else if(up == 2) {
                        if(a[i][j + 1])
                            dp[now].insert(st - bas[j], val);
                    }
                }
            }
        }
    }

    for(int i = 1; i <= dp[now].idx; i++) {
        if(dp[now].id[i] == 0) {
            cout << dp[now].val[i];
            return 0;
        }
    }
    cout << 0;
    return 0;
}