字符串:后缀自动机(SAM)
mydcwfy
·
·
个人记录
后缀自动机(SAM)
1. 主要思想
可以类比于 Trie,但它是一个有向无环图。
可以用来存储一个串的所有子串。
2. 性质
从起点出发,从普通边,可以得到所有子串。
同时,所有边和点的数量在 O(n) 级别。
终点就是包含后缀的点,同时每一个点对应的是不止一个子串。
但所有每一个点对应的子串满足:按长度降序排序后,我们发现每一个串都是前面一个串的后缀。
还有一种边:Link/Father。
这种边构成一棵树。
将一个点所对应的最短的子串的首字母去掉,得到的子串所对应的节点,再由原来节点指向这个节点。
首先定义 endpos(s) 为子串 s 在原串所有出现的位置(尾字母)下标集合。
如果有两个子串的 endpos 相同,则我们可以将其看为一个等价类。
SAM 的状态与所有的等价类一一对应。
证明1:如果 |s1|\leq|s2|,则 s1 是 s2 的后缀当且仅当 endpos(s1) \supseteq endpos(s2)。
这两个易证。
**证明2:**两个不同子串的 $endpos$ 要么包含要么无交集。
易得,如果不为交集,那么必有一个 $pos$ 使两个都满足,则必有一个是另一个的子集。
**证明3:** 对于每一个等价类 $st$,最长的子串为 $longest$,最短的为 $shortest$,若 $shortest\leq |s|\leq longest$,则 $s$ 也是属于该等价类。
### 3. 构造方法
本人能力不够,也是云里雾里,只能帮到这了( ~~逃~~ )。
自己去理解吧。
```cpp
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
const int N=2e6+10;
struct Node{
int len,fa;
int ch[26];
}tr[N];
char str[N];
ll f[N],ans;
int h[N],e[N],ne[N],idx,last=1,tot=1;
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void extend(int c)
{
int p=last;int np=++tot;
f[tot]=1;
tr[np].len=tr[p].len+1;
for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np;
last=tot;
if (!p) tr[np].fa=1;
else{
int q=tr[p].ch[c];
if (tr[q].len==tr[p].len+1) tr[np].fa=q;
else{
int nq=++tot;
tr[nq]=tr[q],tr[nq].len=tr[p].len+1;
tr[q].fa=tr[np].fa=nq;
for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq;
}
}
}
void dfs(int x)
{
for (int i=h[x];~i;i=ne[i])
{
dfs(e[i]);f[x]+=f[e[i]];
}
if (f[x]>1) ans=max(ans,f[x]*tr[x].len);
}
int main()
{
scanf("%s",str);
for (int i=0;str[i];++i) extend(str[i]-'a');
memset(h,-1,sizeof h);
for (int i=2;i<=tot;++i) add(tr[i].fa,i);
dfs(1);
cout<<ans<<endl;
return 0;
}
```
### 4. 例题
#### T1:玄武密码
[题目传送门 Luogu](https://www.luogu.com.cn/problem/P5231)
[题目传送门 AcWing](https://www.acwing.com/problem/content/1285/)
比模板还要裸。
```cpp
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=2e7+10;
struct Node{
int fa,len;
int ch[4];
}tr[N];
int last=1,tot=1;
char str[N];
void extend(int c)
{
int p=last,np=last=++tot;
tr[np].len=tr[p].len+1;
for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np;
if (!p) tr[np].fa=1;
else{
int q=tr[p].ch[c];
if (tr[q].len==tr[p].len+1) tr[np].fa=q;
else{
int nq=++tot;
tr[nq]=tr[q],tr[nq].len=tr[p].len+1;
tr[q].fa=tr[np].fa=nq;
for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq;
}
}
}
int get(char c)
{
switch (c)
{
case 'W':return 0;
case 'N':return 1;
case 'E':return 2;
case 'S':return 3;
}
}
void dfs()
{
int p=1,i;
for (i=0;str[i]&&tr[p].ch[get(str[i])];++i) p=tr[p].ch[get(str[i])];
printf("%d\n",i);
return;
}
int main()
{
int n,m;
cin>>n>>m;
scanf("%s",str);
for (int i=0;i<n;++i) extend(get(str[i]));
while (m--)
{
scanf("%s",str);
dfs();
}
return 0;
}
```
#### T2:最长公共子串
[题目传送门 AcWing](https://www.acwing.com/problem/content/2813/)
[题目传送门 LOJ](https://loj.ac/p/171)
将第一个建后缀自动机,和后面的进行比较即可。
注意要标记回传,否则更新可能不及时。
注意其中的最大最小:
1. 每一次走到一个节点时,应该和当前的取最大值。
2. 不同的字符串之间,一个节点的值应取最小值。
3. 最后得出答案时,应该把不同的节点所保存的值取最大输出。
```cpp
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e6+10,INF=0x3f3f3f3f;
struct Node{
int len,fa;
int ch[26];
}tr[2*N];
int last=1,tot=1,now[N],ans[N];
char str[N];
int h[N],e[N],ne[N],idx;
void extend(int c)
{
int p=last,np=last=++tot;
tr[np].len=tr[p].len+1;
for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np;
if (!p) tr[np].fa=1;
else{
int q=tr[p].ch[c];
if (tr[q].len==tr[p].len+1) tr[np].fa=q;
else{
int nq=++tot;
tr[nq]=tr[q],tr[nq].len=tr[p].len+1;
tr[q].fa=tr[np].fa=nq;
for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq;
}
}
}
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void tree(int x)
{
for (int i=h[x];~i;i=ne[i])
{
tree(e[i]);
now[x]=max(now[x],now[e[i]]);
}
}
void calc()
{
int i,p=1,t=0;
memset(now,0,sizeof now);
for (i=0;str[i];++i)
{
int c=str[i]-'a';
while (p>1&&!tr[p].ch[c]) p=tr[p].fa,t=tr[p].len;
if (tr[p].ch[c]) p=tr[p].ch[c],t++;
now[p]=max(now[p],t);
}
tree(1);
for (int i=1;i<=tot;++i) ans[i]=min(ans[i],now[i]);
}
int main()
{
int n;
cin>>n;n--;
scanf("%s",str);memset(h,-1,sizeof h);
for (int i=0;str[i];++i) extend(str[i]-'a');
for (int i=1;i<=tot;++i) ans[i]=tr[i].len;
for (int i=2;i<=tot;++i) add(tr[i].fa,i);
while (n--)
{
scanf("%s",str);
calc();
}
int finalres=0;
for (int i=1;i<=tot;++i) finalres=max(finalres,ans[i]);
cout<<finalres<<endl;
}
```