kmp
(学字符串肯定就要学这个这个啦)
KMP算法
首先我们要知道前缀函数这个东西。
前缀函数
其实前缀这个东西太简单了,就是一个串前面有多少个字符所形成的字符串。而前缀函数呢。这是oiwiki里的定义:
给定一个长度为
𝑛 的字符串𝑠 ,其前缀函数被定义为一个长度为𝑛 的数组\pi 。 其中\pi[i] 的定义是:
- 如果子串
s[0\dots i] 有一对相等的真前缀与真后缀:s[0\dots k-1] 和s[i - (k - 1) \dots i] ,那么\pi[i] 就是这个相等的真前缀(或者真后缀,因为它们相等)的长度,也就是\pi[i]=k ;- 如果不止有一对相等的,那么
- 如果没有相等的,那么
\pi[i]=0 。简单来说
\pi[i] 就是,子串s[0\dots i] 最长的相等的真前缀与真后缀的长度。
那怎么求又是个问题。
第一个重要的观察是相邻的前缀函数值至多增加 1。
所以显然可以优化成
vector<int> prefix_function(string s) {
int n = (int)s.length();
vector<int> pi(n);
for(int i = 1; i < n; i++)
for(int j = pi[i - 1] + 1; j >= 0; j--) // improved: j=i => j=pi[i-1]+1
if (s.substr(0, j) == s.substr(i - j + 1, j)) {
pi[i] = j;
break;
}
return pi;
}
当然,还有更好的优化方法。有一个很简单的转移方程:
怎么得来的呢?我们在第二个优化中,知道最好情况就是当
如果我们找到了这样的长度 𝑗 j,那么仅需要再次比较 𝑠[𝑖 +1] 和 𝑠[𝑗]。如果它们相等,那么就有 𝜋[𝑖 +1] =j+1。否则,我们需要找到子串 𝑠[0…𝑖] 仅次于 𝑗 j 的第二长度
也就是说 𝑗 j 等价于子串 𝑠[𝜋[𝑖] −1] 的前缀函数值,对应于上图下半部分,即 𝑗 =𝜋[𝜋[𝑖] −1]。同理,次于 𝑗 的前缀函数值,𝑗(2) =𝜋[𝑗 −1] 显然我们可以得到一个关于 j 的状态转移方程:𝑗(𝑛) =𝜋[𝑗(𝑛−1) −1]
所以,现在我们就能优化成
vector<int> prefix_function(string s) {
int n = (int)s.length();
vector<int> pi(n);
for (int i = 1; i < n; i++) {
int j = pi[i - 1];
while (j > 0 && s[i] != s[j]) j = pi[j - 1];
if (s[i] == s[j]) j++;
pi[i] = j;
}
return pi;
}
KMP
学了前缀函数的意义,就是为了更好地去理解
首先我们算出模式串
当 i < n 时:
- 若
T_i == P_j :说明匹配成功,这时i++, j++ - 若
j == m :说明找到了一个完全匹配,记录位置为i - m ,然后j = \pi_{j - 1} 继续寻找,因为\pi_{j - 1} 可以作为下一次的前缀。 - 若
T_i \neq P_j :当j > 0 时,利用\pi 回退,因为前缀函数的意义就是真前缀和真后缀最长的相同的部分,所以即使前面的不配对,后面相同的依然可以作为一个开头,此时我们就j = \pi_{j - 1} ,回退到上一个还能匹配的地方;若j = 0 时,即没有相同的,这是我们就将i++ , 即可。
代码实现:
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int pi[N];
int ans[N], ansi;
int main(){
string T, P;
cin >> T >> P;
int n = T.size(), m = P.size();
// T = ' ' + T, P = ' ' + P;
for(int i = 1; i <= m - 1; i++){
int j = pi[i - 1];
while(j > 0 && P[i] != P[j])j = pi[j - 1];
if(P[i] == P[j])j++;
pi[i] = j;
}
int i = 0, j = 0;
while(i <= n){
if(T[i] == P[j])i++, j++;
if(j == m)ans[++ansi] = i - m, j = pi[j - 1];
if(T[i] != P[j])
if(j > 0)j = pi[j - 1];
else i++;
}
for(int i = 1; i <= ansi; i++)cout << ans[i] + 1 << endl;
for(int i = 0; i <= m - 1; i++)cout << pi[i] << " ";
return 0;
}