题解:P10581 [蓝桥杯 2024 国 A] 重复的串

· · 题解

题目传送门

大家好我不会 KMP,所以我使用了 ACAM + 矩阵快速幂通过了此题。

这种做法可以扩展的多个串的情况。

考虑建立原串的 ACAM,建出字典图,记 i 添加字符 c 后转移到的节点为 tr_{i,c}。可以想到一个 DP:令 f_{i,j,k} 表示当前填到第 i 个字符,在 ACAM 上的第 j 个点上,已经出现了 k 次(0\leq k\leq 2)。设 tot 为 ACAM 上的总点数,那么对于 j\in [0,tot],若加入 tr_{j,c} 后不会出现新的,那么 \forall 0\leq k\leq 2,f_{j,k}\to f_{tr_{j,c},k};若会出现新的,则 \forall 0\leq k\leq 1,f_{j,k}\to f_{tr_{j,c},k+1}。那么答案就是 \sum_{j=0}^{tot}f_{n,j,2}

考虑直接 DP 的时间复杂度是 O(n|\Sigma|tot),过不了。那么我们只需要把转移方程改成矩阵再用矩阵快速幂算就行,时间复杂度 O(tot^3\log n),可以通过。

#include <bits/stdc++.h>
using namespace std;
#define il inline
typedef long long ll;
const int mod = 998244353;
int M;
struct mat
{
    ll a[210][210];
    friend mat operator * (const mat &n1, const mat &n2)
    {
        mat n3 = (mat){{{0}}};
        for(int i = 0;i < M;++i)
            for(int j = 0;j < M;++j)
                for(int k = 0;k < M;++k)
                    n3.a[i][j] = (n3.a[i][j] + n1.a[i][k] * n2.a[k][j] % mod) % mod;
        return n3;
    }
};
il mat ppow(mat A, mat T, int n)
{
    for(;n > 0;T = T * T, n >>= 1)
        if(n & 1) A = A * T;
    return A;
}
int tr[40][26], cnt[40], fail[40], tot;
il void insert(string s)
{
    int root = 0;
    for(char c : s)
    {
        if(!tr[root][c - 'a']) tr[root][c - 'a'] = ++tot;
        root = tr[root][c - 'a'];
    }
    cnt[root]++;
}
queue<int> q;
il void ACAM()
{
    for(int i = 0;i < 26;++i)
        if(tr[0][i]) q.push(tr[0][i]);
    while(q.size())
    {
        int u = q.front();
        q.pop();
        cnt[u] += cnt[fail[u]];
        for(int i = 0;i < 26;++i)
        {
            if(tr[u][i])
            {
                fail[tr[u][i]] = tr[fail[u]][i];
                q.push(tr[u][i]);
            }
            else tr[u][i] = tr[fail[u]][i];
        }
    }
}
il int calc(int j, int k)
{
    return j + (tot + 1) * k;
}
il void print(mat A)
{
    for(int i = 0;i < M;++i)
    {
        for(int j = 0;j < M;++j) cout << A.a[i][j] << " ";
        cout << "\n";
    }
}
mat A, T;
int main()
{
    int n;
    string s;
    cin >> s >> n;
    insert(s);
    ACAM();
    M = calc(tot, 2) + 1;
    A.a[0][0] = 1;
    for(int j = 0;j <= tot;++j)
    {
        for(int c = 0;c < 26;++c)
        {
            for(int k = 0;k <= 2;++k)
            {
                int x = k + cnt[tr[j][c]];
                if(x <= 2) T.a[calc(j, k)][calc(tr[j][c], x)]++;
            }
        }
    }
//  print(T);
    A = ppow(A, T, n);
//  print(A);
    ll ans = 0;
    for(int i = calc(0, 2);i <= calc(tot, 2);++i) ans = (ans + A.a[0][i]) % mod;
    cout << ans;
    return 0;
}