题解 CountTilings

· · 个人记录

咕咕咕，太复杂了，懒得写题解，直接看南外的讲解回放就挺好。

https://archive.topcoder.com/ProblemStatement/pm/17306

这是一道学习 DP 套 DP 的极好题目！

格式化后的代码：

#include <bits/stdc++.h>
// Shorthand: expands a container into its begin/end iterator pair.
#define all(x) (x).begin(), (x).end()
using namespace std;
// All counts are reported modulo this prime.
const int mod = 1000000007;
// Bit b of integer a, as 0 or 1.
#define getbit(a,b) (((a)>>(b))&1)
using ll = long long;
using ull = unsigned long long;
using vi = vector<int>;
// n: bit width of one column mask (masks range over [0, 2^n)).
// m: number of columns the outer DP steps through.
int n, m;
// Layers of the outer DP. Key = a *set* of profiles, encoded with one bit per
// index into in_state; value = number of ways (mod `mod`) to reach exactly
// that set. Only indices 0..8004 exist, so indexing one layer per column
// limits m to 8004.
unordered_map<ull, int> dp[8005];
// flag[mask] = 1 when mask has two adjacent zero bits; such masks are not
// used as profiles. The 4100-entry mask tables (flag, expand, state_val)
// cover masks below 2^12, i.e. n <= 12.
bool flag[4100];
// The profiles: n-bit masks with no two adjacent zero bits, stored in
// increasing order (so the last entry is the all-ones mask). There are
// Fib(n+2) of them (55 for n = 8, 89 for n = 9); the ull set encoding used by
// dp/expand/trans can hold at most 64.
vi in_state;
// expand[j]: set (one bit per in_state index) of the profiles p such that j is
// a subset of p and the extra bits p ^ j split into adjacent pairs (check11).
ull expand[4100];
// state_val[p]: 1ull << (index of p in in_state) for a profile p; stays 0 for
// non-profile masks (globals are zero-initialised).
ull state_val[4100];
// Memo for get_trans: profile set -> list of (successor set, multiplicity).
unordered_map<ull, vector<pair<ull, int> > > trans;
// Returns true iff x can be cleared completely by removing disjoint pairs of
// neighbouring bits {pos, pos + 1} with pos + 1 < n (the global width).
// Pairs are taken greedily from the low end. For this 1-D question that is
// exact: x passes iff every maximal run of set bits in positions 0..n-1 has
// even length. Any bit at position >= n makes it fail.
bool check11(int x) {
    int pos = 0;
    while (pos + 1 < n) {
        const int dom = 3 << pos;  // the two bits pos and pos + 1
        if ((x & dom) == dom)
            x &= ~dom;
        ++pos;
    }
    return x == 0;
}
// Memoised transition list for one outer-DP state.
//
// `now_states` is a set of profiles, encoded as a bitmask over indices into
// in_state. The function tries every column mask `mask` in [0, 2^n). For each
// one it builds the successor set: the union, over every profile p in the set
// with mask a subset of p, of expand[mask | ~p], where ~p is the complement
// within n bits.
//
// Masks that produce the same successor set are merged. The result is a list
// of (successor set, number of masks), sorted by successor set.
//
// The returned reference points into the global `trans` cache. References to
// unordered_map elements stay valid across later insertions.
vector<pair<ull, int> > &get_trans(ull now_states) {
    auto cached = trans.find(now_states);
    if (cached != trans.end())
        return cached->second;

    const int full = (1 << n) - 1;
    map<ull, int> tally;
    for (int mask = 0; mask <= full; ++mask) {
        ull next_states = 0;
        for (size_t idx = 0; idx < in_state.size(); ++idx) {
            if (!getbit(now_states, idx))
                continue;  // profile idx is not in the current set
            const int profile = in_state[idx];
            if ((mask & profile) != mask)
                continue;  // mask may only use bits set in the profile
            next_states |= expand[mask | (profile ^ full)];
        }
        ++tally[next_states];
    }

    vector<pair<ull, int> > &slot = trans[now_states];
    slot.assign(tally.begin(), tally.end());
    return slot;
}
signed main() {
    cin >> n >> m;
    for (int i = 0; i < (1 << n); i++) {
        for (int j = 0; j + 1 < n; j++) {
            if (!((i >> j) & 3)) {
                flag[i] = 1;
                break;
            }
        }
        if (flag[i] == 0) {
            state_val[i] = 1ull << in_state.size();
            in_state.push_back(i);
            for (int j = 0; j <= i; j++)
                if ((j & i) == j && check11(i ^ j))
                    expand[j] |= state_val[i];
        }
    }
    dp[0][1ull << (in_state.size() - 1)] = 1;
    for (int i = 1; i <= m; i++)
        for (const pair<ull, int> &j : dp[i - 1])
            for (pair<ull, int> &k : get_trans(j.first))
                dp[i][k.first] = (dp[i][k.first] + (ll)k.second * j.second) % mod;
    ll ans = 0;
    for (const pair<ull, int> &i : dp[m])
        if (getbit(i.first, in_state.size() - 1))
            (ans += i.second) %= mod;
    cout << ans << endl;
    return 0;
}