KMP 技巧

· · 个人记录

依照 OI Wiki 上的说法,KMP 算法基于前缀函数。
本文以此为基础,给出自己的理解与具体实现代码。

using namespace std;
// Buffer capacity. The strings are read starting at index 1, so index 0 is unused.
const int N = 1e6 + 50;
// a: text that is searched; b: pattern that is searched for (both filled by main).
char a[N], b[N];
int p[N]; // the "next" array: p[j] is the length of the longest prefix of b[1..j] that is also a suffix of it (shorter than b[1..j] itself)
/*
 * KMP pattern matching.
 *
 * Reads two whitespace-delimited tokens from stdin: the text (into a[1..n])
 * and the pattern (into b[1..m]). Prints every 1-based start position at
 * which the pattern occurs in the text, one per line, followed by the
 * prefix-function values p[1..m] of the pattern separated by spaces.
 */
int main()
{
    cin >> (a + 1) >> (b + 1);
    const int n = strlen(a + 1);
    const int m = strlen(b + 1);

    // Phase 1: prefix function of the pattern. `k` is the length of the
    // border currently being extended; on a mismatch fall back to the next
    // shorter border via p[].
    p[1] = 0;
    for (int i = 2, k = 0; i <= m; ++i) {
        while (k > 0 && b[k + 1] != b[i]) k = p[k];
        if (b[k + 1] == b[i]) ++k;
        p[i] = k;
    }

    // Phase 2: scan the text. `k` is the number of pattern characters
    // currently matched against the text ending at position i.
    for (int i = 1, k = 0; i <= n; ++i) {
        while (k > 0 && b[k + 1] != a[i]) k = p[k];
        if (b[k + 1] == a[i]) ++k;
        if (k == m) {
            // '\n' instead of endl: there can be a match per text position,
            // and flushing the stream on each one is needlessly slow.
            cout << i - m + 1 << '\n';
            k = p[k]; // keep going so overlapping occurrences are found too
        }
    }

    for (int i = 1; i <= m; ++i) cout << p[i] << " ";
    return 0;
}
using namespace std;
// Buffer capacity; the string is stored starting at index 1.
const int N = 1e6 + 60;
// a: input string; n: its length, read from input; p: prefix-function values
// (main later overwrites nonzero entries with the shortest nonzero border);
// ans: sum accumulated and printed by main.
char a[N]; int n, p[N]; long long ans;
/*
 * Reads the length n and a string a[1..n]. For every prefix i it finds the
 * shortest nonzero border length j (j = i when the prefix has no border) and
 * adds i - j to ans; the total is printed at the end.
 */
int main()
{
    cin >> n >> (a + 1);

    // Standard prefix function; p[1] stays 0 from static initialisation.
    for (int i = 2, border = 0; i <= n; ++i) {
        while (border > 0 && a[border + 1] != a[i]) border = p[border];
        if (a[border + 1] == a[i]) ++border;
        p[i] = border;
    }

    // next[i], next[next[i]], next[next[next[i]]], ... are exactly the border
    // lengths of prefix i, so following the chain until the next hop would be
    // 0 yields the smallest positive one.
    for (int i = 1; i <= n; ++i) {
        int shortest = i;
        while (p[shortest] != 0) shortest = p[shortest];
        // Memoise the result (comparable to path compression) so later
        // chains passing through i stop after one hop.
        if (p[i] != 0) p[i] = shortest;
        ans += i - shortest;
    }

    cout << ans << endl;
    return 0;
}
using namespace std;
// Buffer capacity; the string is stored starting at index 1.
const int N = 1e6 + 50;
// Input string, filled by main.
char a[N];
// fail[i]: prefix-function value of a[1..i], computed by main.
int fail[N];
/*
 * Reads a string a[1..n] and prints, for every prefix length 1..n, how many
 * times that prefix occurs in the string, separated by spaces.
 */
int main()
{
    cin >> (a + 1);
    const int n = strlen(a + 1);

    fail[1] = 0;
    for (int i = 2, k = 0; i <= n; ++i) {
        while (k > 0 && a[k + 1] != a[i]) k = fail[k];
        if (a[k + 1] == a[i]) ++k;
        fail[i] = k;
    }

    vector<int> ans(n + 1);
    // Each position i contributes one occurrence of its longest border.
    for (int i = 1; i <= n; ++i) ++ans[fail[i]];
    // Push counts down the failure chain, longest prefixes first, so shorter
    // borders inherit the occurrences of the longer ones containing them.
    for (int i = n; i > 1; --i) ans[fail[i]] += ans[i];
    // The extra 1 accounts for the prefix occurring as itself at position 1.
    for (int i = 1; i <= n; ++i) cout << ans[i] + 1 << " ";
    return 0;
}
using namespace std;
// N: buffer capacity (strings are stored starting at index 1).
// Inf is not referenced anywhere in this snippet.
const int N = 2 * 1e7 + 50, Inf = 0x3f3f3f3f;
// z[]: filled by Z(); p[]: filled by exkmp().
// NOTE(review): i64 is declared outside this snippet — presumably a 64-bit integer alias; confirm.
i64 p[N], z[N];
void Z(char *s, int n)
{
    for(int i = 1; i <= n; i ++) z[i] = 0;
    z[1] = n;
    for(int i = 2, l = 0, r = 0; i <= n; i ++) {
        if(i <= r) z[i] = min(z[i - l + 1], r - i + 1);
        while(i + z[i] <= n && s[i + z[i]] == s[z[i] + 1]) z[i] ++;
        if(i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
    }
    return;
}
void exkmp(char *s, int n, char *t, int m)
{
    Z(t, m);
    for(int i = 1; i <= n; i++) p[i] = 0;
    for(int i = 1, l = 0, r = 0; i <= n; i++) {
        if (i <= r) p[i] = min(z[i-l+1], r - i + 1);
        while (i + p[i] <= n && s[i + p[i]] == t[p[i] + 1]) p[i] ++;
        if (i + p[i] - 1 > r) l = i, r = i + p[i] - 1;
    }
}
// s: text, t: pattern; both are read by main starting at index 1.
char s[N], t[N];
/*
 * Reads text s and pattern t, runs exkmp, then prints two checksums:
 * the XOR over i = 1..m of i * (z[i] + 1), a newline, and the XOR over
 * i = 1..n of i * (p[i] + 1).
 * NOTE(review): i64, LL and print() are defined outside this snippet.
 */
int main()
{
    cin >> (s + 1) >> (t + 1);
    i64 n = strlen(s + 1), m = strlen(t + 1);
    exkmp(s, n, t, m);

    // Checksum over the pattern's Z-function.
    LL ans = 0;
    for (int i = 1; i <= m; ++i) {
        ans ^= 1ll * i * (z[i] + 1);
    }
    print(ans);
    putchar('\n');

    // Checksum over the text-vs-pattern match lengths.
    ans = 0;
    for (int i = 1; i <= n; ++i) {
        ans ^= 1ll * i * (p[i] + 1);
    }
    print(ans);
    return 0;
}
using namespace std;
// Buffer capacity; the string is stored starting at index 1.
const int N = 1e6 + 50;
// Input string, filled by main.
char s[N];
// n: length of s; m: number of queries read by main.
// fa[i][k]: the 2^k-th ancestor of node i, where the parent fa[i][0] is the
//           prefix-function value of s[1..i] (set in KMP()); 0 acts as the root.
// dep[i]: depth of node i, with dep[0] = 0.
int n, fa[N][23], dep[N], m;
/*
 * Builds the failure tree of s[1..n]: the parent of node i is the
 * prefix-function value of s[1..i], and node 0 is the root. Fills dep[] with
 * node depths and fa[][] with the binary-lifting ancestor table.
 * fa[1][0] is left at its zero initial value, i.e. node 1 hangs off the root.
 */
void KMP()
{
    dep[0] = 0;
    dep[1] = 1;
    for (int i = 2, border = 0; i <= n; ++i) {
        while (border > 0 && s[border + 1] != s[i]) border = fa[border][0];
        if (s[border + 1] == s[i]) ++border;
        fa[i][0] = border;
        // border < i, so dep[border] has already been computed.
        dep[i] = dep[border] + 1;
    }

    // fa[i][k] = 2^k-th ancestor, built from two 2^(k-1) jumps.
    for (int k = 1; k <= 21; ++k) {
        for (int i = 1; i <= n; ++i) {
            fa[i][k] = fa[fa[i][k - 1]][k - 1];
        }
    }
}
/*
 * Returns the deepest node that is a strict ancestor of both x and y in the
 * failure tree built by KMP(), using binary lifting.
 * (Tarjan's offline LCA or heavy-light decomposition would be faster.)
 */
int LCA(int x, int y)
{
    if (dep[x] < dep[y]) swap(x, y);

    // Lift the deeper node until both are on the same level.
    for (int k = 21; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y]) x = up;
    }

    // Deliberately no `if (x == y) return x;` here: when one node is an
    // ancestor of the other (or x == y) the loop below does not move and the
    // function returns the PARENT of that node, not the node itself.
    // NOTE(review): presumably because the answer must be a border strictly
    // shorter than both prefixes — confirm against the problem statement.
    for (int k = 21; k >= 0; --k) {
        if (fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
/*
 * Reads the string s, builds its failure tree, then answers m queries.
 * Each query is a pair (x, y) and the answer printed is LCA(x, y), one per
 * line.
 */
int main()
{
    // The original wrote `ios_base::sync_with_stdio,` — that only names the
    // function and never calls it, so stdio sync stayed enabled.
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> (s + 1);
    n = strlen(s + 1);
    KMP();

    cin >> m;
    for (int i = 1, x, y; i <= m; i++) {
        cin >> x >> y;
        // '\n' rather than endl: flushing after every query would undo the
        // benefit of the fast-I/O setup above.
        cout << LCA(x, y) << '\n';
    }
    return 0;
}
  • 推论 :设字符串的长度为 $len$,对于任意 $i \in [1, len]$,若 $i \bmod (i - next_i) = 0$ 且 $next_i > 0$,则前 $i$ 个字符构成循环串:最小循环节的长度为 $i - next_i$,循环次数为 $\dfrac{i}{i - next_i}$。

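  • 证明 :记 $d = i - next_i$。由 $next_i$ 的定义,$s[1..next_i] = s[d+1..i]$,即对任意 $1 \le k \le i - d$ 都有 $s_k = s_{k+d}$,所以前 $i$ 个字符每隔 $d$ 位重复一次。当 $d \mid i$ 时,前 $i$ 个字符恰好由 $s[1..d]$ 重复 $\dfrac{i}{d}$ 次得到;又因为 $next_i$ 是最长的公共前后缀,所以 $d$ 是最小的循环节长度。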

using namespace std;
// Buffer capacity; each string is stored starting at index 1.
const int N = 2 * 1e6 + 50;
// Current test-case string, filled by main.
char ch[N];
// fail[i]: prefix-function value of ch[1..i], recomputed by main for every test case.
int fail[N];
/*
 * For each input string (until a line consisting of a single '.'), prints the
 * largest k such that the string is some block repeated k times, or 1 when no
 * such repetition exists.
 */
int main()
{
    // fail[1] is always 0 and fail[2..n] is rewritten below before it is
    // read, so the array does not need to be cleared between test cases
    // (the original memset the whole ~8 MB buffer on every iteration).
    fail[1] = 0;

    while (cin >> (ch + 1)) {
        const int n = strlen(ch + 1);
        if (n == 1 && ch[1] == '.') break;

        for (int i = 2, k = 0; i <= n; ++i) {
            while (k > 0 && ch[k + 1] != ch[i]) k = fail[k];
            if (ch[k + 1] == ch[i]) ++k;
            fail[i] = k;
        }

        // n - fail[n] is the shortest period of the string; it tiles the
        // string exactly only when it divides n.
        const int period = n - fail[n];
        if (fail[n] != 0 && n % period == 0) {
            cout << (n / period) << '\n';
        } else {
            cout << "1" << '\n';
        }
    }
    return 0;
}