kmp

· · 个人记录

(学字符串肯定就要学这个这个啦)

KMP算法

首先我们要知道前缀函数这个东西。

前缀函数

其实前缀这个东西太简单了,就是一个串前面有多少个字符所形成的字符串。而前缀函数呢。这是oiwiki里的定义:

给定一个长度为 𝑛 的字符串 𝑠,其前缀函数被定义为一个长度为 𝑛 的数组 \pi。 其中 \pi[i] 的定义是:

  • 如果子串 s[0\dots i] 有一对相等的真前缀与真后缀:s[0\dots k-1]s[i - (k - 1) \dots i],那么 \pi[i] 就是这个相等的真前缀(或者真后缀,因为它们相等)的长度,也就是 \pi[i]=k
  • 如果不止有一对相等的,那么
  • 如果没有相等的,那么 \pi[i]=0

简单来说 \pi[i] 就是,子串 s[0\dots i] 最长的相等的真前缀与真后缀的长度。

那怎么求又是个问题。O(N^3) 是显然的,枚举即可。考虑优化:

第一个重要的观察是相邻的前缀函数值至多增加 1。

所以显然可以优化成 O(N^2)。 代码如下:(oiwiki)。

vector<int> prefix_function(string s) {
    int n = (int)s.length();
    vector<int> pi(n);
    for(int i = 1; i < n; i++)
        for(int j = pi[i - 1] + 1; j >= 0; j--)  // improved: j=i => j=pi[i-1]+1
        if (s.substr(0, j) == s.substr(i - j + 1, j)) {
            pi[i] = j;
            break;
        } 
    return pi;
}

当然,还有更好的优化方法。有一个很简单的转移方程:

𝑗(𝑛) =𝜋[𝑗(𝑛−1) −1], (𝑗(𝑛−1) >0)

怎么得来的呢?我们在第二个优化中,知道最好情况就是当 s[i + 1] == s[\pi[i]] 时,只需比较一次即可。因此我们可以观察一些性质: 可以发现什么呢?就是

如果我们找到了这样的长度 𝑗 j,那么仅需要再次比较 𝑠[𝑖 +1] 和 𝑠[𝑗]。如果它们相等,那么就有 𝜋[𝑖 +1] =j+1。否则,我们需要找到子串 𝑠[0…𝑖] 仅次于 𝑗 j 的第二长度

也就是说 𝑗 j 等价于子串 𝑠[𝜋[𝑖] −1] 的前缀函数值,对应于上图下半部分,即 𝑗 =𝜋[𝜋[𝑖] −1]。同理,次于 𝑗 的前缀函数值,𝑗(2) =𝜋[𝑗 −1] 显然我们可以得到一个关于 j 的状态转移方程:𝑗(𝑛) =𝜋[𝑗(𝑛−1) −1]

所以,现在我们就能优化成 O(n) 的复杂度了。(可以用回退的思想去理解)

vector<int> prefix_function(string s) {
    int n = (int)s.length();
    vector<int> pi(n);
    for (int i = 1; i < n; i++) {
        int j = pi[i - 1];
        while (j > 0 && s[i] != s[j]) j = pi[j - 1];
        if (s[i] == s[j]) j++;
        pi[i] = j;
    }
    return pi;
}

KMP

学了前缀函数的意义,就是为了更好地去理解 KMP 算法。

首先我们算出模式串 P 的前缀函数,然后用双指针开始扫。(令文本串为 T, 模式串为 P

i < n 时:

代码实现:

#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int pi[N];
int ans[N], ansi;
int main(){
    string T, P;
    cin >> T >> P;
    int n = T.size(), m = P.size();
    // T = ' ' + T, P = ' ' + P; 
    for(int i = 1; i <= m - 1; i++){
        int j = pi[i - 1];
        while(j > 0 && P[i] != P[j])j = pi[j - 1];
        if(P[i] == P[j])j++;
        pi[i] = j;
    }
    int i = 0, j = 0;
    while(i <= n){
        if(T[i] == P[j])i++, j++;
        if(j == m)ans[++ansi] = i - m, j = pi[j - 1];
        if(T[i] != P[j])
            if(j > 0)j = pi[j - 1];
            else i++;
    }
    for(int i = 1; i <= ansi; i++)cout << ans[i] + 1 << endl;
    for(int i = 0; i <= m - 1; i++)cout << pi[i] << " ";
    return 0;
}