P3808

· · 个人记录

【模板】AC自动机(简单版)

AC 自动机似乎就是一个 Trie 上的 KMP,用来处理多模式串的匹配问题。

我们先用模式串构建 Trie 树,然后求出每个节点的 Next,但区别是这里的公共前缀后缀,不一定就是该节点所在模式串的前缀后缀了。也就是说,AC 自动机的 Trie 树上的这些个 Next,不一定就是返祖边,还有可能是横向边。

那么构建 Trie 树的方法需要 BFS,先处理好浅层的 Next,然后扩展其子节点的 Next。显然子节点的后缀是需要其父节点支撑的,所以尽管有横向边我们仍然可以通过不断查找 Next 来确定子节点的 Next 值。

先给出这一段代码:

void AC_pre() {
    queue<ll> q;
    Next[0]=0;//初始化最开始的几个节点
    for(ll c=0;c<26;c++) {
        ll u=trie[0][c];
        if(u) {
            Next[u]=0;q.push(u);lst[u]=0;
            //BFS预处理,lst 的作用在后面说明
        }
    }
    while(!q.empty()) {
        ll h=q.front();q.pop();
        for(ll c=0;c<26;c++) {
            ll u=trie[h][c];
            if(!u) {
                trie[h][c]=trie[Next[h]][c];//一个补边的优化,在后面解释
                continue;
            }
            q.push(u);
            ll v=Next[h];
            while(v&&!trie[v][c]) v=Next[v];//寻找 h 的某个子节点 u 的 Next
            Next[u]=trie[v][c];
            lst[u]=cnt[Next[u]]?Next[u]:lst[Next[u]];//lst 的作用之后再说
        }
    }
}

我们这里给出了一个 lst,这个数组表示的也是公共前缀后缀的位置,但与 Next 不同的地方在于这个指向的位置是一个完整的单词。

至于为什么需要添加这个数组,在某蓝书中差不多是这样解释的:我们单纯的找到一个词汇,显然很有可能不会考虑以改词的后缀为一个完整的词汇的统计,所以只有加上这个才能完全统计。我大概看了好一会才明白。据蓝书的解释,lst 在正规文献里叫做后缀链接(suffix link)。

然后呢,这个 lst 的计算方法也很明确,我们仍然是采取跳 Next 的方法,如果说正好跳一次到了,那 lst 的位置就确定了,反之还需要再跳吗?显然不需要,因为 lst_{Next} 已经被计算过了,所以我们直接使用即可。

然后解释一下那个补边的优化。这个加边的意思其实就是在失配的时候跳 Next 找到可以匹配的节点,只不过我们在预处理的阶段就事先完成了这个工作。这样的话在 AC 自动机匹配的时候我们就可以完全省略通过跳 Next 重新匹配的过程了。之所以能这样用还是因为 AC 自动机是以 Trie 树为基础的,正好这些空着的子节点不用白不用,我们用它来优化是非常合适的。

AC 自动机的主要过程就是建 Trie,预处理 Next,以及最后的多模式匹配。总的时间复杂度为 O(km+nm),其中 k 为模式串个数,m 为模式串平均长度,n 为文本串长度。

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define ll long long
using namespace std;

const ll N=1e6;

ll n,ans,tot;

ll cnt[N+5],trie[N+5][26],Next[N+5],lst[N+5];

bool vis[N+5];

char tmp[N+5],s[N+5];

void ins(char *str) {
    ll len=strlen(str),p=0;
    for(ll k=0;k<len;k++) {
        ll ch=str[k]-'a';
        if(!trie[p][ch]) trie[p][ch]=++tot;
        p=trie[p][ch];
    }
    cnt[p]++;
}

void count(ll j) {
    if(j&&!vis[j]) {
        ans+=cnt[j];vis[j]=1;
        count(lst[j]);
    }
}

void AC_pre() {
    queue<ll> q;
    Next[0]=0;
    for(ll c=0;c<26;c++) {
        ll u=trie[0][c];
        if(u) {
            Next[u]=0;q.push(u);lst[u]=0;
        }
    }
    while(!q.empty()) {
        ll h=q.front();q.pop();
        for(ll c=0;c<26;c++) {
            ll u=trie[h][c];
            if(!u) {
                trie[h][c]=trie[Next[h]][c];
                continue;
            }
            q.push(u);
            ll v=Next[h];
            while(v&&!trie[v][c]) v=Next[v];
            Next[u]=trie[v][c];
            lst[u]=cnt[Next[u]]?Next[u]:lst[Next[u]];
        }
    }
}

void AC(char *str) {
    ll len=strlen(str);
    ll j=0;
    for(ll i=0;i<len;i++) {
        ll c=str[i]-'a';
//      printf("In function AC:");
//      putchar(c+'a');putchar('\n');
        j=trie[j][c];
//      printf("j=%lld\n",j);
        if(cnt[j]) count(j);
        else {
            if(lst[j]) count(lst[j]);
        }
    }
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

int main() {

    n=read();

    for(ll i=1;i<=n;i++) {
        scanf("%s",tmp);
        ins(tmp);
    }
    AC_pre();
    scanf("%s",s);
    AC(s);
    write(ans);

    return 0;
}