题解:P2289 [HNOI2004] 邮递员

· · 题解

本文节选自插头 DP 学习笔记。

2.1 P5056 【模板】插头 DP

考虑用括号表示法刻画“哈密顿回路”的限制。具体而言,如下图所示,以中间绿色的分界线把哈密顿回路切开,然后考虑分界线上方部分的切面。

显然,分界线上方的哈密顿回路分别构成了若干个连通分量,且每个连通分量均为链的形式。我们将链左端的切面转化为左括号 \texttt{(},将链右端的切面转化为右括号 \texttt{)},其余切面转化为占位符 \texttt{\#}。那么图中的哈密顿回路的上半部分就可以用 \texttt{(\#(\#\#))\#\#} 表示。(一个连通分量列数更小的端点为左括号,列数更大的端点为右括号)

分界线上每个切面的状态就称为插头。因为当转移到 (x, y) 时,分界线只有两个拐点,且都在 (x, y) 处,所以我们称 (x, y) 水平方向的切面为下插头,竖直方向的切面为右插头。

容易发现,哈密顿回路能用括号序列来刻画,关键就在于 任意两个连通分量的端点不会相交。这和括号匹配的性质恰好吻合。

而带占位符的括号序列可以使用三进制数来表示。为了减小常数,我们将三进制数按照四进制数来存储,因为二的正整数次幂作为进制可以使用位运算优化常数。但是有一个新的问题:转化为四进制数后,状态的值域过大。此时可以使用哈希表来存储状态。

设计插头 DP 为:\mathrm{dp}_{x, y, S} 表示当前考虑到位置 (x, y),分界线的插头状态为 S 的方案数。

接下来考虑转移。外层肯定是依次枚举位置 (x, y),重点在于内层转移。

首先障碍格的转移在判断合法性后直接继承即可。其余格子在转移的时候需要对右插头、下插头的括号类型进行分类讨论:

注意,换行的时候需要将所有状态集体左移 2 位(对应着四进制下左移 1 位),以匹配新的分界点形态。同时在插头延伸的时候需要判断延伸的方向是否有障碍,否则在进行换行的时候会引发错误。

其中,最后两种情况的分类讨论可以画图进行理解。而后两种的转移则需要用到括号串 \pm 1 转化的性质去做。

如图所示,这是“右插头是右括号,下插头是右括号”情况的解释。

时间复杂度 O(nm3^n)

一些有关代码的技巧:

#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 15, M = 2e6;

struct HashTable{
    int B = 1999993, h[M], id[M], idx, ne[M];
    ll val[M];

    void clear() {
        memset(h, 0, sizeof(h));
        for(int i = 0; i <= idx; i++) {
            val[i] = id[i] = ne[i] = 0;
        }
        idx = 0;
    }

    void insert(int st, ll v)
    {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] += v;
                return;
            }
        }

        val[++idx] = v;
        id[idx] = st;
        ne[idx] = h[st % B];
        h[st % B] = idx;
    }
} dp[2];

int n, m, a[N][N], ex, ey;

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c; cin >> c;
            if(c == '.') {
                ex = i, ey = j;
                a[i][j] = 1;
            }
        }
    }

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = (i << 1); bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    ll ans = 0;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k]; ll val = dp[pre].val[k];
                int rht = (st >> bit[j - 1]) & 3;
                int dwn = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!rht && !dwn)
                        dp[now].insert(st, val);
                }
                else if(!rht && !dwn) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | (1 << bit[j - 1]) | (2 << bit[j]), val);
                }
                else if(!dwn) {
                    if(a[i][j + 1]) dp[now].insert(st + rht * (bas[j] - bas[j - 1]), val);
                    if(a[i + 1][j]) dp[now].insert(st, val);
                }
                else if(!rht) {
                    if(a[i + 1][j]) dp[now].insert(st + dwn * (bas[j - 1] - bas[j]), val);
                    if(a[i][j + 1]) dp[now].insert(st, val);
                }
                else {

                    if(rht == 1 && dwn == 2) {
                        if(i == ex && j == ey)
                            ans += val;
                    }
                    else if(rht == 2 && dwn == 1) {
                        dp[now].insert(st - rht * bas[j - 1] - dwn * bas[j], val);
                    }
                    else if(rht == 1 && dwn == 1) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j; ; p++) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst -= bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                    else if(rht == 2 && dwn == 2) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j - 1; ; p--) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst += bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                }
            }
        }
    }

    cout << ans;
    return 0;
}

拓展:如果题目中不限制一个哈密顿回路,而是可以有多个哈密顿回路的话,就不用记录每个括号到底是左还是右了,因为不需要区分合并时是否形成回路。可以用二进制状压存储括号的存在性,达到 O(nm2^m) 的复杂度。

2.2 P2289 [HNOI2004] 邮递员

下文中称行数为 n,列数为 m,与题目中的定义相反。

考虑观察题目中回路的形态:

因此当 n, m > 1 的时候路线的长度(边数)一定得是 n\times m。又由于三角形斜边长度大于直角边,所以不能斜着走。容易发现这就是让我们数有多少个有向哈密顿回路,可以直接套用 2.1 的做法,最后答案乘个 2。时间复杂度 O(nm3^n)

注意答案是个很大的数,需要使用 __int128 存储。

#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 30, M = 2e6;

void write(i128 x) {
    if(x < 0) {x = -x; putchar('-'); }
    if(x < 10) { putchar('0' + x); return; }
    write(x / 10); putchar('0' + x % 10);
}

struct HashTable{
    int B = 1999993, h[M], id[M], idx, ne[M];
    i128 val[M];

    void clear() {
        memset(h, 0, sizeof(h));
        for(int i = 0; i <= idx; i++) {
            val[i] = id[i] = ne[i] = 0;
        }
        idx = 0;
    }

    void insert(int st, i128 v)
    {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] += v;
                return;
            }
        }

        val[++idx] = v;
        id[idx] = st;
        ne[idx] = h[st % B];
        h[st % B] = idx;
    }
} dp[2];

int n, m, a[N][N], ex, ey;

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> m >> n;
    if(n == 1 || m == 1) {
        cout << 1;
        return 0;
    }

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c = '.';
            if(c == '.') {
                ex = i, ey = j;
                a[i][j] = 1;
            }
        }
    }

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = (i << 1); bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    i128 ans = 0;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k]; i128 val = dp[pre].val[k];
                int rht = (st >> bit[j - 1]) & 3;
                int dwn = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!rht && !dwn)
                        dp[now].insert(st, val);
                }
                else if(!rht && !dwn) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | (1 << bit[j - 1]) | (2 << bit[j]), val);
                }
                else if(!dwn) {
                    if(a[i][j + 1]) dp[now].insert(st + rht * (bas[j] - bas[j - 1]), val);
                    if(a[i + 1][j]) dp[now].insert(st, val);
                }
                else if(!rht) {
                    if(a[i + 1][j]) dp[now].insert(st + dwn * (bas[j - 1] - bas[j]), val);
                    if(a[i][j + 1]) dp[now].insert(st, val);
                }
                else {

                    if(rht == 1 && dwn == 2) {
                        if(i == ex && j == ey)
                            ans += val;
                    }
                    else if(rht == 2 && dwn == 1) {
                        dp[now].insert(st - rht * bas[j - 1] - dwn * bas[j], val);
                    }
                    else if(rht == 1 && dwn == 1) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j; ; p++) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst -= bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                    else if(rht == 2 && dwn == 2) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j - 1; ; p--) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst += bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                }
            }
        }
    }

    write(2 * ans);
    return 0;
}