SAM 简记

· · 算法·理论

SAM 是一张 DAG,满足对于 S 中的任意子串 S_{l\sim r},都能从源点出发,令 ilr,不断通过当前节点 u 的字符 S_i 对应出边走向另一节点 v,使当 u 相同但 S_i 不同时 v 也不同,且不陷入空节点。

一个显然的想法是,把 S 的每一个后缀插入一棵字典树中,但时空复杂度均高达 O(n^2),不可取。

如果定义 endpos(T) 表示子串 TS 中出现的位置的右端点所构成的集合,那么 endpos 相等的子串便是等价的了,故可缩成一点。

考虑如下过程:开始有一空串 T,对应节点 p=1endpos(T)=\{1,2,...,n\}。然后往空串前加一字符,得到 26 个新串(不妨设字符集为 \{a,b,...,z\}),对于每个新串 U,如果 US 的子串且 endpos(U) \ne endpos(S),那么新建一节点 q,并认 p 为父亲,接着对 Uq 递归处理,建成一棵 parent tree。

显然,如果 p 有一个儿子,那么必然是失掉了一个前缀,故 |endpos(T)|-1=|endpos(U)|;否则,相当于将 endpos(T) 拆分为若干互不相交且都不为空的子集。综合分析,可得 parent tree 点数为 O(n) 级别。

parent tree 中的边,相当于从一个 endpos 等价类中的最短子串开头删去一字符。而 SAM 中的边,在相当于在一个 endpos 等价类中的最长子串结尾加一字符。parent tree 对 SAM 的建立起到不可或缺的辅助作用。

常用的建图方法是每次往 S 末端插入一个字符,时时维护对应的 SAM。这种方法在线,且均摊 O(n)

先上代码:

int tot=1,lst=1;
void insert(int x){
    int p=lst,np=lst=++tot;
    tr[np].len=tr[p].len+1;
    // Step1
    for(;p&&!tr[p].ch[x];p=tr[p].fa) tr[p].ch[x]=np;
    // Step2
    if(!p) tr[np].fa=1; //Step 3
    else{
        int q=tr[p].ch[x];
        if(tr[q].len==tr[p].len+1) tr[np].fa=q; //Step 4
        else{
            int nq=++tot;tr[nq]=tr[q];
            tr[nq].len=tr[p].len+1;
            tr[np].fa=tr[q].fa=nq;
            for(;p&&tr[p].ch[x]==q;p=tr[p].fa) tr[p].ch[x]=nq; // Step5
        }
    }
}

该流程可分为五步。

方便起见,设加入了字符 c,令加入前的串为旧串,加入后的串为新串,新串长度为 n

一、新建一个节点 np 表示整个新串的所属节点。

二、从长到短遍历旧串的每个后缀,如果加入 c 后得到空节点,那么说明该旧串后缀加上 c 后的新串后缀未在旧串中出现,从而其 endpos\{n\},与整个新串的 endpos 相等,都属于 np

三、如果 p 是空节点,说明 c 未在旧串中出现,因此 np 认源点为父亲。

四、如果 p 不是空节点,且 p 加入 c 后所得节点 q 的最长字符串长度恰为 p 的最长字符串长度加一,即 q 的最短字符串长度。从而 q 中只有一种字符串,且是新串的后缀,从而 npq 为父亲。

五、如果 q 中字符串长度不一,那么 q 中只有长度为 p 的最长字符串长度加一的字符串是新串的后缀,其余字符串的 endpos 不变(否则在第二步应该被遍历)。我们把这个 endpos 中增加了 n 的字符串 V 取出,新建一节点 nq 存放。在 q 中任意字符串和 V 的末尾加上同样的字符,所得的节点自然相同(二者 endpos 的唯一差异为 n,而 n 后不能再插字符)。之后,qnpnq 为父亲。最终,清除 q 的历史遗存,把 p 的所有祖先中原本连向 q 的边都连向 nq

模板题的完整代码:

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5;
typedef long long ll;
int n;
char S[N];
struct Node{
    int ch[26];
    int len,fa,sum;
} tr[N<<1];
int tot=1,lst=1;
void insert(int x){
    int p=lst,np=lst=++tot;
    tr[np].sum++,tr[np].len=tr[p].len+1;
    for(;p&&!tr[p].ch[x];p=tr[p].fa) tr[p].ch[x]=np;
    if(!p) tr[np].fa=1; 
    else{
        int q=tr[p].ch[x];
        if(tr[q].len==tr[p].len+1) tr[np].fa=q; 
        else{
            int nq=++tot;tr[nq]=tr[q],tr[nq].sum=0;
            tr[nq].len=tr[p].len+1;
            tr[np].fa=tr[q].fa=nq;
            for(;p&&tr[p].ch[x]==q;p=tr[p].fa) tr[p].ch[x]=nq; 
        }
    }
}
vector<int> son[N<<1];
void dfs(int u){
    for(int v:son[u]) dfs(v),tr[u].sum+=tr[v].sum;
}
int main(){
    scanf("%s",S+1),n=strlen(S+1);
    for(int i=1;i<=n;i++) insert(S[i]-'a');
    for(int i=2;i<=tot;i++) son[tr[i].fa].push_back(i);
    dfs(1);
    ll ans=0;
    for(int i=1;i<=tot;i++) if(tr[i].sum!=1) ans=max(ans,(ll)tr[i].len*tr[i].sum);
    printf("%lld",ans);
    return 0;
}

SAM 极具巧思,值得仔细玩味。