字符串:后缀自动机(SAM)

· · 个人记录

后缀自动机(SAM)

1. 主要思想

可以类比于 Trie,但它是一个有向无环图。

可以用来存储一个串的所有子串。

2. 性质

从起点出发,从普通边,可以得到所有子串。

同时,所有边和点的数量在 O(n) 级别。

终点就是包含后缀的点,同时每一个点对应的是不止一个子串。

但所有每一个点对应的子串满足:按长度降序排序后,我们发现每一个串都是前面一个串的后缀。

还有一种边:Link/Father

这种边构成一棵树。

将一个点所对应的最短的子串的首字母去掉,得到的子串所对应的节点,再由原来节点指向这个节点。

首先定义 endpos(s) 为子串 s 在原串所有出现的位置(尾字母)下标集合。

如果有两个子串的 endpos 相同,则我们可以将其看为一个等价类。

SAM 的状态与所有的等价类一一对应。

证明1:如果 |s1|\leq|s2|,则 s1s2 的后缀当且仅当 endpos(s1) \supseteq endpos(s2)

这两个易证。 **证明2:**两个不同子串的 $endpos$ 要么包含要么无交集。 易得,如果不为交集,那么必有一个 $pos$ 使两个都满足,则必有一个是另一个的子集。 **证明3:** 对于每一个等价类 $st$,最长的子串为 $longest$,最短的为 $shortest$,若 $shortest\leq |s|\leq longest$,则 $s$ 也是属于该等价类。 ### 3. 构造方法 本人能力不够,也是云里雾里,只能帮到这了( ~~逃~~ )。 自己去理解吧。 ```cpp #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; typedef long long ll; const int N=2e6+10; struct Node{ int len,fa; int ch[26]; }tr[N]; char str[N]; ll f[N],ans; int h[N],e[N],ne[N],idx,last=1,tot=1; void add(int a,int b) { e[idx]=b,ne[idx]=h[a],h[a]=idx++; } void extend(int c) { int p=last;int np=++tot; f[tot]=1; tr[np].len=tr[p].len+1; for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np; last=tot; if (!p) tr[np].fa=1; else{ int q=tr[p].ch[c]; if (tr[q].len==tr[p].len+1) tr[np].fa=q; else{ int nq=++tot; tr[nq]=tr[q],tr[nq].len=tr[p].len+1; tr[q].fa=tr[np].fa=nq; for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq; } } } void dfs(int x) { for (int i=h[x];~i;i=ne[i]) { dfs(e[i]);f[x]+=f[e[i]]; } if (f[x]>1) ans=max(ans,f[x]*tr[x].len); } int main() { scanf("%s",str); for (int i=0;str[i];++i) extend(str[i]-'a'); memset(h,-1,sizeof h); for (int i=2;i<=tot;++i) add(tr[i].fa,i); dfs(1); cout<<ans<<endl; return 0; } ``` ### 4. 例题 #### T1:玄武密码 [题目传送门 Luogu](https://www.luogu.com.cn/problem/P5231) [题目传送门 AcWing](https://www.acwing.com/problem/content/1285/) 比模板还要裸。 ```cpp #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=2e7+10; struct Node{ int fa,len; int ch[4]; }tr[N]; int last=1,tot=1; char str[N]; void extend(int c) { int p=last,np=last=++tot; tr[np].len=tr[p].len+1; for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np; if (!p) tr[np].fa=1; else{ int q=tr[p].ch[c]; if (tr[q].len==tr[p].len+1) tr[np].fa=q; else{ int nq=++tot; tr[nq]=tr[q],tr[nq].len=tr[p].len+1; tr[q].fa=tr[np].fa=nq; for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq; } } } int get(char c) { switch (c) { case 'W':return 0; case 'N':return 1; case 'E':return 2; case 'S':return 3; } } void dfs() { int p=1,i; for (i=0;str[i]&&tr[p].ch[get(str[i])];++i) p=tr[p].ch[get(str[i])]; printf("%d\n",i); return; } int main() { int n,m; cin>>n>>m; scanf("%s",str); for (int i=0;i<n;++i) extend(get(str[i])); while (m--) { scanf("%s",str); dfs(); } return 0; } ``` #### T2:最长公共子串 [题目传送门 AcWing](https://www.acwing.com/problem/content/2813/) [题目传送门 LOJ](https://loj.ac/p/171) 将第一个建后缀自动机,和后面的进行比较即可。 注意要标记回传,否则更新可能不及时。 注意其中的最大最小: 1. 每一次走到一个节点时,应该和当前的取最大值。 2. 不同的字符串之间,一个节点的值应取最小值。 3. 最后得出答案时,应该把不同的节点所保存的值取最大输出。 ```cpp #include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=1e6+10,INF=0x3f3f3f3f; struct Node{ int len,fa; int ch[26]; }tr[2*N]; int last=1,tot=1,now[N],ans[N]; char str[N]; int h[N],e[N],ne[N],idx; void extend(int c) { int p=last,np=last=++tot; tr[np].len=tr[p].len+1; for (;p&&!tr[p].ch[c];p=tr[p].fa) tr[p].ch[c]=np; if (!p) tr[np].fa=1; else{ int q=tr[p].ch[c]; if (tr[q].len==tr[p].len+1) tr[np].fa=q; else{ int nq=++tot; tr[nq]=tr[q],tr[nq].len=tr[p].len+1; tr[q].fa=tr[np].fa=nq; for (;p&&tr[p].ch[c]==q;p=tr[p].fa) tr[p].ch[c]=nq; } } } void add(int a,int b) { e[idx]=b,ne[idx]=h[a],h[a]=idx++; } void tree(int x) { for (int i=h[x];~i;i=ne[i]) { tree(e[i]); now[x]=max(now[x],now[e[i]]); } } void calc() { int i,p=1,t=0; memset(now,0,sizeof now); for (i=0;str[i];++i) { int c=str[i]-'a'; while (p>1&&!tr[p].ch[c]) p=tr[p].fa,t=tr[p].len; if (tr[p].ch[c]) p=tr[p].ch[c],t++; now[p]=max(now[p],t); } tree(1); for (int i=1;i<=tot;++i) ans[i]=min(ans[i],now[i]); } int main() { int n; cin>>n;n--; scanf("%s",str);memset(h,-1,sizeof h); for (int i=0;str[i];++i) extend(str[i]-'a'); for (int i=1;i<=tot;++i) ans[i]=tr[i].len; for (int i=2;i<=tot;++i) add(tr[i].fa,i); while (n--) { scanf("%s",str); calc(); } int finalres=0; for (int i=1;i<=tot;++i) finalres=max(finalres,ans[i]); cout<<finalres<<endl; } ```