[HAOI2015] 字符串拆分

· · 个人记录

首先发现 f_i=\sum_{j=1}^m f_{i-j}

以 $m=3$ 为例: $$ \begin{bmatrix} f_n & f_{n-1} & f_{n-2} \end{bmatrix} \times \begin{bmatrix} 1 & 1 & 0\\ 1 & 0 & 1\\ 1 & 0 & 0 \end{bmatrix} = \begin{bmatrix} f_{n+1} & f_n & f_{n-1} \end{bmatrix} $$ 撕烤如何计算 $g$。根据矩阵的结合律,先考虑转移矩阵的计算,最后再整体乘上初始矩阵。 考虑这样一个事实:$f(x+y)=G^{x+y}=G^x\times G^y=f(x)\times f(y)$,其中 $G$ 为转移矩阵。 基于此,可以 dp 求解原数字串的划分。 比如:$g(123)=g(12)\times f(3)+g(1)\times f(23)+f(123)$。 这样求解的前提是我们知道原串每个子段的 $f$ 值,这个可以递推求解。 设 $D_{i,j}=f(s_{i\sim j})$,则: $$D_{i,j}= \begin{cases} f(s_i), & \mathrm{if}\ i=j\\ f(s_i\times 10^{j-i})\times D_{i+1,j}, & \mathrm{Otherwise} \end{cases} $$ 则有: $$ g_i=\sum_{j=0}^{i-1} g_j\times D_{j+1,i} $$ 预处理 $D_{i,j}$ 之前先算出来 $f(c\times 10^k)$,时间复杂度可以做到 $\mathcal{O(n^2m^3)}$,我这里写少了,只处理了 $f(10^k)$,不过开 O2 也可以过。 ``` cpp // Problem: P3176 [HAOI2015]数字串拆分 // Contest: Luogu // URL: https://www.luogu.com.cn/problem/P3176 // Memory Limit: 125 MB // Time Limit: 1000 ms // // Powered by CP Editor (https://cpeditor.org) #include <bits/stdc++.h> using i64 = long long; const int maxn = 505; const int maxm = 5; const i64 mod = 998244353; int n,m; char s[maxn]; struct matrix { i64 g[maxm][maxm]; matrix() { memset(g , 0 , sizeof(g)); } void clear() { memset(g , 0 , sizeof(g)); return ; } void init() { for(int i = 0;i < m;++ i) g[i][i] = 1; return ; } void output() { puts("TEST:"); for(int i = 0;i < m;++ i) { for(int j = 0;j < m;++ j) printf("%lld ",g[i][j]); puts(""); } return ; } matrix operator * (const matrix& p)const { matrix c; for(int k = 0;k < m;++ k) for(int i = 0;i < m;++ i) for(int j = 0;j < m;++ j) (c.g[i][j] += g[i][k] * p.g[k][j] % mod) %= mod; return c; } matrix operator + (const matrix& p)const { matrix c; for(int i = 0;i < m;++ i) for(int j = 0;j < m;++ j) c.g[i][j] = (g[i][j] + p.g[i][j]) % mod; return c; } }f,g[maxn],pw[maxn],d[maxn][maxn]; matrix power(matrix x,int y) { matrix ans; ans.init(); for(;y;y >>= 1) { if(y & 1)ans = ans * x; x = x * x; } return ans; } int main() { scanf("%s %d",s + 1,&m); n = strlen(s + 1); for(int i = 1;i <= n;++ i) s[i] ^= '0'; for(int i = 1;i < m;++ i) pw[0].g[i - 1][i] = 1; for(int i = 0;i < m;++ i) pw[0].g[i][0] = 1; f.g[0][0] = 1; for(int i = 1;i < n;++ i) pw[i] = power(pw[i - 1] , 10); for(int j = 1;j <= n;++ j) for(int i = j;i;-- i) { if(i == j) d[i][j] = power(pw[0] , s[i]); else d[i][j] = power(pw[j - i] , s[i]) * d[i + 1][j]; } for(int i = 1;i <= n;++ i) { g[i] = d[1][i]; for(int j = i - 1;j;-- j) g[i] = g[i] + g[j] * d[j + 1][i]; } printf("%lld\n",(f * g[n]).g[0][0]); return 0; } ```