[HAOI2015] 字符串拆分
KyTLAS
·
·
个人记录
首先发现 f_i=\sum_{j=1}^m f_{i-j}。
以 $m=3$ 为例:
$$
\begin{bmatrix}
f_n & f_{n-1} & f_{n-2}
\end{bmatrix}
\times
\begin{bmatrix}
1 & 1 & 0\\
1 & 0 & 1\\
1 & 0 & 0
\end{bmatrix}
=
\begin{bmatrix}
f_{n+1} & f_n & f_{n-1}
\end{bmatrix}
$$
撕烤如何计算 $g$。根据矩阵的结合律,先考虑转移矩阵的计算,最后再整体乘上初始矩阵。
考虑这样一个事实:$f(x+y)=G^{x+y}=G^x\times G^y=f(x)\times f(y)$,其中 $G$ 为转移矩阵。
基于此,可以 dp 求解原数字串的划分。
比如:$g(123)=g(12)\times f(3)+g(1)\times f(23)+f(123)$。
这样求解的前提是我们知道原串每个子段的 $f$ 值,这个可以递推求解。
设 $D_{i,j}=f(s_{i\sim j})$,则:
$$D_{i,j}=
\begin{cases}
f(s_i), & \mathrm{if}\ i=j\\
f(s_i\times 10^{j-i})\times D_{i+1,j}, & \mathrm{Otherwise}
\end{cases}
$$
则有:
$$
g_i=\sum_{j=0}^{i-1} g_j\times D_{j+1,i}
$$
预处理 $D_{i,j}$ 之前先算出来 $f(c\times 10^k)$,时间复杂度可以做到 $\mathcal{O(n^2m^3)}$,我这里写少了,只处理了 $f(10^k)$,不过开 O2 也可以过。
``` cpp
// Problem: P3176 [HAOI2015]数字串拆分
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3176
// Memory Limit: 125 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
using i64 = long long;
const int maxn = 505;
const int maxm = 5;
const i64 mod = 998244353;
int n,m;
char s[maxn];
struct matrix {
i64 g[maxm][maxm];
matrix() {
memset(g , 0 , sizeof(g));
}
void clear() {
memset(g , 0 , sizeof(g));
return ;
}
void init() {
for(int i = 0;i < m;++ i)
g[i][i] = 1;
return ;
}
void output() {
puts("TEST:");
for(int i = 0;i < m;++ i) {
for(int j = 0;j < m;++ j)
printf("%lld ",g[i][j]);
puts("");
}
return ;
}
matrix operator * (const matrix& p)const {
matrix c;
for(int k = 0;k < m;++ k)
for(int i = 0;i < m;++ i)
for(int j = 0;j < m;++ j)
(c.g[i][j] += g[i][k] * p.g[k][j] % mod) %= mod;
return c;
}
matrix operator + (const matrix& p)const {
matrix c;
for(int i = 0;i < m;++ i)
for(int j = 0;j < m;++ j)
c.g[i][j] = (g[i][j] + p.g[i][j]) % mod;
return c;
}
}f,g[maxn],pw[maxn],d[maxn][maxn];
matrix power(matrix x,int y) {
matrix ans;
ans.init();
for(;y;y >>= 1) {
if(y & 1)ans = ans * x;
x = x * x;
}
return ans;
}
int main() {
scanf("%s %d",s + 1,&m);
n = strlen(s + 1);
for(int i = 1;i <= n;++ i)
s[i] ^= '0';
for(int i = 1;i < m;++ i)
pw[0].g[i - 1][i] = 1;
for(int i = 0;i < m;++ i)
pw[0].g[i][0] = 1;
f.g[0][0] = 1;
for(int i = 1;i < n;++ i)
pw[i] = power(pw[i - 1] , 10);
for(int j = 1;j <= n;++ j)
for(int i = j;i;-- i) {
if(i == j)
d[i][j] = power(pw[0] , s[i]);
else
d[i][j] = power(pw[j - i] , s[i]) * d[i + 1][j];
}
for(int i = 1;i <= n;++ i) {
g[i] = d[1][i];
for(int j = i - 1;j;-- j)
g[i] = g[i] + g[j] * d[j + 1][i];
}
printf("%lld\n",(f * g[n]).g[0][0]);
return 0;
}
```