回文自动机学习笔记
一、基本定义
回文自动机 (PAM) 是一种可以处理有关单个串内的回文信息的算法。
(1) 结构
一个 PAM 由两棵树组成,两棵树的根节点分别为
(2) 点
一个 PAM 中最多有
(3) 边
PAM 上的边是有向边,并且有一个边权
(4) fail 指针
一个节点
一开始
二、构造
假如我们要插入的串
一开始的 PAM 时这样的(黄色的边表示 fail,黑色数字表示深度,黑色边表示边):
然后插入字符
那么从
此时的 PAM 如下图:
对于字符
因为
因为
因为
此时 PAM 结构如下:
对于下一个字符
因为
因为
因为
对于下一个字符
因为
因为
因为
因为
因为
对于下一个字符
因为
因为
因为
对于下一个字符
因为
因为
因为
Code(其中 nw 代表当前插入的
int gfail (int x) {
for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
return x ;
}
void ext (int x) {
las = gfail (las) ;
if (!ch[las][x]) {
dep[++tot] = dep[las] + 2 ;
fail[tot] = ch[gfail (fail[las])][x] ;
ch[las][x] = tot ;
}
sz[ch[las][x]]++ , las = ch[las][x] ;
}
三、额外操作
其实在构建 PAM 的时候可以额外求出一个
Code:
int gfail (int x) {
for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
return x ;
}
void ext (int x) {
las = gfail (las) ;
if (!ch[las][x]) {
dep[++tot] = dep[las] + 2 ;
fail[tot] = ch[gfail (fail[las])][x] ;
if (dep[tot] <= 2) tr[tot] = fail[tot] ;
else {
int tmp = tr[las] ;
for (; s[nw - dep[tmp] - 1] != s[nw] || (dep[tmp] + 2) * 2 > dep[tot] ; tmp = fail[tmp]) ;
tr[tot] = ch[tmp][x] ;
}
ch[las][x] = tot ;
}
las = ch[las][x] ;
}
四、例题
P4555 [国家集训队]最长双回文串
对于正反串分别建立 PAM,用一个桶记一下即可。
Code:
#include <bits/stdc++.h>
using namespace std ;
const int MAXN = 1e5 + 10 ;
int ch[MAXN][26] , fail[MAXN] , dep[MAXN] , las = 1 , tot = 1 , nw ;
int n , t1[MAXN] , t2[MAXN] , ans ;
char s[MAXN] ;
int gfail (int x) {
for (; s[nw - dep[x] - 1] != s[nw] ; x = fail[x]) ;
return x ;
}
void ext (int x) {
las = gfail (las) ;
if (!ch[las][x]) {
dep[++tot] = dep[las] + 2 ;
fail[tot] = ch[gfail (fail[las])][x] ;
ch[las][x] = tot ;
}
t2[nw] = max (t2[nw] , dep[ch[las][x]]) ;
las = ch[las][x] ;
}
int gf2 (int x) {
for (; s[nw + dep[x] + 1] != s[nw] ; x = fail[x]) ;
return x ;
}
void ext2 (int x) {
las = gf2 (las) ;
if (!ch[las][x]) {
dep[++tot] = dep[las] + 2 ;
fail[tot] = ch[gf2 (fail[las])][x] ;
ch[las][x] = tot ;
}
t1[nw] = max (t1[nw] , dep[ch[las][x]]) ;
las = ch[las][x] ;
}
int main () {
fail[0] = fail[1] = 1 ; dep[1] = -1 ;
scanf ("%s" , s + 1) ;
n = strlen (s + 1) ;
for (nw = 1 ; nw <= n ; nw++) ext (s[nw] - 'a') ;
memset (fail , 0 , sizeof (fail)) ;
memset (dep , 0 , sizeof (dep)) ; memset (ch , 0 , sizeof (ch)) ;
fail[0] = fail[1] = 1 ; dep[1] = -1 ; tot = las = 1 ;
for (nw = n ; nw ; nw--) ext2 (s[nw] - 'a') ;
for (int i = 1 ; i < n ; i++)
ans = max (ans , t2[i] + t1[i + 1]) ;
printf ("%d\n" , ans) ;
return 0 ;
}