回文自动机学习笔记

· · 个人记录

一、基本定义

回文自动机 (PAM) 是一种可以处理有关单个串内的回文信息的算法。

(1) 结构

一个 PAM 由两棵树组成,两棵树的根节点分别为 01。其中,0 节点的子树存储长度为偶数的回文串,1 节点的子树存储长度为奇数的回文串。

(2) 点

一个 PAM 中最多有 |S|+2 个节点。每个节点到它所对应的根(节点 01)的链都代表了若干个在原串中的回文串。每个节点有一个深度 dep_x,代表它所对应的回文串长度。

(3) 边

PAM 上的边是有向边,并且有一个边权 c,其中 c 代表一个字符。设这条边从 x 指向 v,边权为 c,那么这条边就代表在 x 代表的回文串左右两端各加上一个字符 c 构成回文串 v。特别地,若 x=1 则这条边只代表往空串中加入一个字符 c 构成回文串 v。因此,为了方便计算,我们设 dep_1=-1

(4) fail 指针

一个节点 x 的 fail 指针代表节点 x 所代表的回文串的最长的且为回文串的真后缀(即不能等于自身)所对应的节点(01 号节点除外)。若这个节点不存在,fail_x=0

一开始 fail_0=fail_1=1

二、构造

假如我们要插入的串 S=abbcac,记上一个插入的节点为 las,一开始 las=1,那么插入过程可以这样表示:

一开始的 PAM 时这样的(黄色的边表示 fail,黑色数字表示深度,黑色边表示边):

然后插入字符 a,此时 S_{1-dep_{las}-1}=S_1,因此新节点 2 直接插在 1 下方,12 的边权为 adep_2=dep_1+2=1,但我们暂时不插入 2 节点,原因过一会儿讲 。

那么从 fail_{las}=1 开始跳,因为 S_{1-dep_1-1}=S_1,所以 fail_2=son_{1,a},这时暂时不插入 2 节点的作用就体现出来了。如果这时候已经插入 2 节点,fail_2 将会指向自身,就不符合上面 fail 的定义了。由于 2 节点还没有插入,son_{1,a} 不存在,所以 fail_2=0。那么上一个插入的点就为 2,将 las 置为 2

此时的 PAM 如下图:

对于字符 b,因为此时 las=2,首先从节点 2 开始跳。

因为 S_{2-dep_2-1}\ne S_2,执行 las=fail_{las},此时 las=0

因为 S_{2-dep_0-1}\ne S_2,执行 las=fail_{las},此时 las=1

因为 S_{2-dep_1-1}=S_2,此时 son_{1,b} 还不存在,新节点 3 直接插入在 1 下方。此时 fail_3dep_3 的求法同上,这里不展开讲了。此时将 las 置为 3

此时 PAM 结构如下:

对于下一个字符 b,此时 las=3,从节点 3 开始跳:

因为 S_{3-dep_3-1}\ne S_3,执行 las=fail_{las},此时 las=0

因为 S_{3-dep_0-1}=S_3,此时 son_{0,b} 还不存在,新节点 4 可以插入在 0 下方。那么 dep_4=dep_0+2=2,同上,我们暂时不插入节点 4。我们记 t=fail_{las}=1,从 t 开始跳:

因为 S_{3-dep_1-1}=S_3,所以 fail_4=son_{1,b}=3。将 las 置为 4。此时 PAM 结构如下:

对于下一个字符 c,此时 las=4,从节点 4 开始跳:

因为 S_{4-dep_4-1}\ne S_4,执行 las=fail_{las},此时 las=3

因为 S_{4-dep_3-1}\ne S_4,执行 las=fail_{las},此时 las=0

因为 S_{4-dep_0-1}\ne S_4,执行 las=fail_{las},此时 las=1

因为 S_{4-dep_1-1}=S_4,此时 son_{1,c} 还不存在,新节点 5 可以插入在 1 下方。那么 dep_5=dep_1+2=1。我们暂时不插入节点 5,记 t=fail_{las}=1,从 t 开始跳:

因为 S_{4-dep_1-1}=S_4,而此时由于没插入节点 5son_{1,c} 还不存在,因此 fail_5=0。将 las 置为 5。此时 PAM 结构如下:

对于下一个字符 a,此时 las=5,从节点 5 开始跳:

因为 S_{5-dep_5-1}\ne S_5,执行 las=fail_{las},此时 las=0

因为 S_{5-dep_0-1}\ne S_5,执行 las=fail_{las},此时 las=1

因为 S_{5-dep_1-1}=S_5,此时 son_{1,a} 已经存在了(节点 2),因此我们把 las 置为 son_{1,a}(如果题目需要维护 size,就将 2 节点的 size 加一,此时 size_2=2)。

对于下一个字符 c,此时 las=2,从节点 2 开始跳:

因为 S_{6-dep_2-1}=S_6,此时 son_{2,c} 还不存在,新节点 6 可以插入在 2 下方。那么 dep_6=dep_2+2=3。我们暂时不插入节点 6,记 t=fail_{las}=0,从 t 开始跳:

因为 S_{6-dep_0-1}\ne S_6,执行 t=fail_t,此时 t=1

因为 S_{6-dep_1-1}=S_6,所以 fail_6=son_{1,c}=5。将 las 置为 6。此时构造结束,PAM 结构如下:

Code(其中 nw 代表当前插入的 S 下标):

int gfail (int x) {
    for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
    return x ;
}
void ext (int x) {
    las = gfail (las) ;
    if (!ch[las][x]) {
        dep[++tot] = dep[las] + 2 ;
        fail[tot] = ch[gfail (fail[las])][x] ;
        ch[las][x] = tot ;
    }
    sz[ch[las][x]]++ , las = ch[las][x] ;
}

三、额外操作

其实在构建 PAM 的时候可以额外求出一个 trans 指针,代表该串最长且不超过该串长度一半的后缀回文串所对应的节点。这在许多题中有很大的作用。

Code:

int gfail (int x) {
    for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
    return x ;
}
void ext (int x) {
    las = gfail (las) ;
    if (!ch[las][x]) {
        dep[++tot] = dep[las] + 2 ;
        fail[tot] = ch[gfail (fail[las])][x] ;
        if (dep[tot] <= 2) tr[tot] = fail[tot] ;
        else {
            int tmp = tr[las] ;
            for (; s[nw - dep[tmp] - 1] != s[nw] || (dep[tmp] + 2) * 2 > dep[tot] ; tmp = fail[tmp]) ;
            tr[tot] = ch[tmp][x] ;
        }
        ch[las][x] = tot ;
    }
    las = ch[las][x] ;
}

四、例题

P4555 [国家集训队]最长双回文串

对于正反串分别建立 PAM,用一个桶记一下即可。

Code:

#include <bits/stdc++.h>
using namespace std ;
const int MAXN = 1e5 + 10 ;
int ch[MAXN][26] , fail[MAXN] , dep[MAXN] , las = 1 , tot = 1 , nw ;
int n , t1[MAXN] , t2[MAXN] , ans ;
char s[MAXN] ;
int gfail (int x) {
    for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
    return x ;
}
void ext (int x) {
    las = gfail (las) ;
    if (!ch[las][x]) {
        dep[++tot] = dep[las] + 2 ;
        fail[tot] = ch[gfail (fail[las])][x] ;
        ch[las][x] = tot ;
    }
    t2[nw] = max (t2[nw] , dep[ch[las][x]]) ;
    las = ch[las][x] ;
}
int gf2 (int x) {
    for (; s[nw + dep[x] + 1] != s[nw] ; x = fail[x]) ;
    return x ;
}
void ext2 (int x) {
    las = gf2 (las) ;
    if (!ch[las][x]) {
        dep[++tot] = dep[las] + 2 ;
        fail[tot] = ch[gf2 (fail[las])][x] ;
        ch[las][x] = tot ;
    }
    t1[nw] = max (t1[nw] , dep[ch[las][x]]) ;
    las = ch[las][x] ;
}
int main () {
    fail[0] = fail[1] = 1 ; dep[1] = -1 ;
    scanf ("%s" , s + 1) ;
    n = strlen (s + 1) ;
    for (nw = 1 ; nw <= n ; nw++) ext (s[nw] - 'a') ;
    memset (fail , 0 , sizeof (fail)) ;
    memset (dep , 0 , sizeof (dep)) ; memset (ch , 0 , sizeof (ch)) ;
    fail[0] = fail[1] = 1 ; dep[1] = -1 ; tot = las = 1 ;
    for (nw = n ; nw ; nw--) ext2 (s[nw] - 'a') ;
    for (int i = 1 ; i < n ; i++)
        ans = max (ans , t2[i] + t1[i + 1]) ;
    printf ("%d\n" , ans) ;
    return 0 ;
}