题解 CountTilings
咕咕咕，这题太复杂了，懒得写题解；南外的讲题回放已经讲得很好。
https://archive.topcoder.com/ProblemStatement/pm/17306
这是一道学习 DP 套 DP 的极好的题！（内层 DP 是按列转移的轮廓状态，外层 DP 以"当前可达的轮廓集合"为状态进行计数。）
格式化后的代码：
#include <bits/stdc++.h>
// Expands to a container's [begin, end) iterator pair.
#define all(x) (x).begin(), (x).end()
using namespace std;
// All counts are reported modulo 1e9+7.
const int mod = 1000000007;
// Bit b of a, as 0 or 1.
#define getbit(a,b) (((a)>>(b))&1)
using ll = long long;
using ull = unsigned long long;
using vi = vector<int>;
// Read in main: n is the bit width of a column mask ("profile"),
// m is the number of DP steps (presumably the number of columns).
int n, m;
// dp[i] maps a set of reachable profiles (bit k set <=> in_state[k] reachable)
// to the number of ways, mod `mod`, to arrive at exactly that set after i steps.
// Indexed 0..m, so this sizing assumes m <= 8004 — confirm against the constraints.
unordered_map<ull, int> dp[8005];
// flag[mask] = 1 when mask has two adjacent 0 bits; such masks are not profiles.
// The mask-indexed tables below hold 2^n entries, so they fit n <= 12.
bool flag[4100];
// All unflagged masks in increasing order; position k is bit k of a profile set.
vi in_state;
// expand[j] = set of profiles i with j a submask of i and i^j splittable into
// adjacent bit pairs (see check11).
ull expand[4100];
// state_val[mask] = one-hot bit (1ull << index in in_state) for an unflagged mask.
ull state_val[4100];
// Memo for get_trans: profile set -> list of (successor profile set, multiplicity).
unordered_map<ull, vector<pair<ull, int> > > trans;
// Returns true iff the set bits of x can be removed by greedily peeling off
// pairs of adjacent set bits, scanning positions 0..n-1 from the low end
// (with n >= 4: 0b0110 -> true, 0b0111 -> false). Any bit left unpaired —
// including any bit at position n or above — makes the result false.
// Presumably each pair stands for one vertical domino within a column.
bool check11(int x) {
    int leftover = x;
    int row = 0;
    while (row + 1 < n) {
        const int both = 3 << row;
        if ((leftover & both) == both)
            leftover &= ~both;
        ++row;
    }
    return leftover == 0;
}
// Memoised transition of the outer DP.
//
// `now_states` is a set of inner-DP profiles: bit k set <=> profile in_state[k]
// is reachable. For every mask `remove` in [0, 2^n), the successor set is the
// union of expand[remove | ~s] over every reachable profile s that contains
// `remove` as a submask (~s taken within the low n bits). The returned list
// pairs each distinct successor set with the number of `remove` masks that
// produce it, sorted by successor set.
//
// Results are cached in `trans`. The returned reference stays valid while
// `trans` keeps growing: unordered_map insertions may invalidate iterators on
// rehash, but never references to existing elements.
vector<pair<ull, int> > &get_trans(ull now_states) {
    // Single hash lookup on the (hot) cached path instead of count() + operator[].
    auto hit = trans.find(now_states);
    if (hit != trans.end())
        return hit->second;

    const int full_mask = (1 << n) - 1;  // loop-invariant: all n low bits set
    map<ull, int> now_trans;             // ordered -> cached list sorted by successor set
    for (int remove = 0; remove <= full_mask; remove++) {
        ull final_out_state = 0;
        for (size_t id = 0; id < in_state.size(); id++) {
            const int s = in_state[id];
            // Only reachable profiles that keep every `remove` bit set contribute.
            if (getbit(now_states, id) && (remove & s) == remove)
                final_out_state |= expand[remove | (s ^ full_mask)];
        }
        now_trans[final_out_state]++;
    }
    // Key is known to be absent here, so emplace always inserts.
    return trans.emplace(now_states, vector<pair<ull, int> >(all(now_trans))).first->second;
}
signed main() {
cin >> n >> m;
for (int i = 0; i < (1 << n); i++) {
for (int j = 0; j + 1 < n; j++) {
if (!((i >> j) & 3)) {
flag[i] = 1;
break;
}
}
if (flag[i] == 0) {
state_val[i] = 1ull << in_state.size();
in_state.push_back(i);
for (int j = 0; j <= i; j++)
if ((j & i) == j && check11(i ^ j))
expand[j] |= state_val[i];
}
}
dp[0][1ull << (in_state.size() - 1)] = 1;
for (int i = 1; i <= m; i++)
for (const pair<ull, int> &j : dp[i - 1])
for (pair<ull, int> &k : get_trans(j.first))
dp[i][k.first] = (dp[i][k.first] + (ll)k.second * j.second) % mod;
ll ans = 0;
for (const pair<ull, int> &i : dp[m])
if (getbit(i.first, in_state.size() - 1))
(ans += i.second) %= mod;
cout << ans << endl;
return 0;
}