题解:AT_agc020_e [AGC020E] Encoding Subsets
ljk8886
·
·
题解
题目传送门 AT
题目大意
一个长度为 n 的 01 串 s 的压缩方案数。
## 题目分析
### 记号说明
对于两个字符串 $s$ 和 $t$,其中 $t$ 是 $s$ 的 **子串**。
1. 记 $s - t$ 为在 $s$ 中去掉 $t$ 后的字符串。
2. 对于一个正整数 $k$,$k \times t$ 表示 $k$ 个 $t$ 拼接起来。
### 分析
设 $f(s)$ 代表 $s$ 的压缩方案数。
考虑转移。对于 $s$ 的最后一个字符 $c$,有两种情况:
1. $c$ 不参与压缩时:
- $c = 0$ 时,$f(s) = f(s - c)$。
- $c = 1$ 时,$f(s) = f(s - c) \times 2$,因为可以 $1 \to 0$。
2. $c$ 参与压缩时:我们可以 $\mathcal{O}(n ^ 2)$ 暴力枚举 $c$ 要被压缩的最后一段。
设压缩的字符串为 $t$,压缩了 $k$ 个,那么 $f(s) = f(t) * f(s - t \times k)$。
由于状态数远远达不到上界,所以可以使用 **记忆化搜索**,时间复杂度 $\mathcal{O}(能过)$。
因为 $n \le 100$,我们可以用 `__int128` 来表示 $01$ 串。
## code
```cpp
#include <bits/stdc++.h>
#define ft first
#define sd second
#define endl '\n'
#define pb push_back
#define md make_pair
#define gc() getchar()
#define pc(ch) putchar(ch)
#define umap unordered_map
#define pque priority_queue
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 bint;
typedef pair<int, int> pii;
typedef pair<pii, int> pi1;
typedef pair<pii, pii> pi2;
const ll INF = 0x3f3f3f3f;
const db Pi = acos(-1.0);
inline ll read()
{
ll res = 0, f = 1; char ch = gc();
while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = gc();
while (ch >= '0' && ch <= '9') res = (res << 1) + (res << 3) + (ch ^ 48), ch = gc();
return res * f;
}
inline void write(ll x)
{
if (x < 0) x = -x, pc('-');
if (x > 9) write(x / 10);
pc(x % 10 + '0');
}
inline void writech(ll x, char ch) { write(x), pc(ch); }
const ll mod = 998244353;
const int N = 1e2 + 5;
map<bint, ll> dp[N];
void Add(ll &a, ll b) { a += b, a %= mod; }
ll dfs(int n, bint s)
{
if (n == 0) return 1;
if (dp[n].count(s)) return dp[n][s];
// c 不参与压缩
ll res = dfs(n - 1, s >> 1) * ((s & 1) + 1) % mod;
// f(s) = f(s - c), c = 0
// f(s - c) * 2, c = 1
bint tmp = 0, t = 0;
for (int i = 1; i <= n; i++) // i 是 t 的长度
{
tmp = (tmp << 1) | 1;
t = s & tmp;
for (int j = i * 2; j <= n; j += i) // 看一下能压缩多少个
{
t &= (s >> (j - i));
Add(res, dfs(i, t) * dfs(n - j, (s >> j)) % mod);
}
}
return dp[n][s] = res;
}
int main()
{
string s; cin >> s;
int n = s.size();
bint S = 0;
for (int i = n - 1; i >= 0; i--) S = (S << 1 | (s[i] - '0'));
write(dfs(n, S));
return 0;
}
```