AC自动机

· · 个人记录

AC自动机是用来求给定多个模式串,求有多少个在一个文本串中出现过。本质上是在字典树上跑KMP。

AC自动机中有一个失配指针fail类似于KMP中的next数组,都是求出在当前这个字符(结点)不满足匹配的情况下,模式串会跳转到哪个位置继续匹配。

我们现在来看一个用目标字符串集合{abd,abdk, abchijn, chnit, ijabdf, ijaij}构造出来的AC自动机

字典树建好后,就可以开始求fail数组了,fail数组代表的意思是从根节点到当前结点的路径中,所有的后缀,与整个字典树中的前缀最长的公共部分

所以在当前结点失配后,模式串会根据fail数组来到与这个后缀最长公共的前缀的结点,例如:

当前的文本串为 abchn 如图进行匹配,当匹配完h后,将会拿文本串的n与字典树上的i这个结点比较,发现不符合,那么将会跳转到h这个结点的fail值,如图可以发现,跳转到了chn这条路径上,此时匹配完成

我们可以发现如果当当前结点失配后,那么模式串会跳转到它父亲结点匹配的地方 如下代码:

for(int i=0;i<26;i++)
        {
            int c=tr[t][i];
            if(!c) continue;
            int j=ne[t];
            while(j&&!tr[j][i]) j=ne[j];
            if(tr[j][i]) j=tr[j][i];
            ne[c]=j;
            q.push(c);
        }

u代表从这个往下匹配结点,每匹配完一层,一定保证了以上所有层的失配指针完成了指向,如果当u下面a+'i'这个字母存在的话,失配后就会跳转到u结点的失配结点v,且v的下一个一定是a+'i'这个字母 可以借助下图来理解:

如图可知,我们当前正在进行匹配的结点是x1这个结点,它的儿子y这个结点存在,那我们首先会跳到x1的失配结点x2处,观察x2的所有结点是否有y这个字母,发现没有,那我们接着再跳到x2这个结点的失配结点处,就是x3结点处.....依次类推,直到找到某一个前缀,它的儿子里面有y这个结点

此时有一个问题,就是按照我们这么不断往上跳的话,那么fail指针其实指的是u这个结点父亲结点失配后会跳向的地址但是跳向的新结点v的儿子结点不一定等于u这个结点,所以我们需要不断的往上一层跳来找到满足的地址。

再仔细回想一下我们的fail指针的定义,在当前结点失配后,这条路径上所有的后缀与字典树上前缀相同的部分的最长公共部分实际上我们是直接想跳到一个结点v,这个结点v的儿子与原来结点的儿子相同

如下代码:

for(int i=0; i<26; i++) {
            if(tree[u][i]) {
                fail[tree[u][i]]=tree[fail[u]][i];
                q.push(tree[u][i]);
            }
            else tree[u][i]=tree[fail[u]][i];
        }

u是当前操作的结点,假如儿子a+'i'这个结点存在,直接把它的失配结点指向它父亲结点的失配后指向的结点的下个a+'i'这个字母的编号

否则如果u的下面没有a+'i'这个字母(但是在用文本串跑字典树的时候可能会有这个字母),就把下面这个结点的编号直接改为下一个字母存在a+'i'的前缀的编号

匹配的代码相对起来好理解一些:

int query(string s)
{
    int ans=0,now=0;
    for(int i=0; i<s.size(); i++) {
        now=tree[now][s[i]-'a'];
        for(int j=now; j&&cnt[j]!=-1; j=fail[j]) {
            ans+=cnt[j];
            cnt[j]=-1;
        }
    }
    return ans;
}

例题:AC自动机模板

代码:

#include<bits/stdc++.h>
#define int long long
#define N 1000004
using namespace std;
int rt,tree[N][30],fail[N],cnt[N];
void insert(string s) 
{
    int root=0;
    for(int i=0; i<s.size(); i++) {
        int a=s[i]-'a';
        if(!tree[root][a]) {
            tree[root][a]=++rt;
        }
        root=tree[root][a];
    }
    cnt[root]++;
}
void bfs()
{
    queue<int> q;
    for(int i=0; i<26; i++) {
        if(tree[0][i]) {
            q.push(tree[0][i]);
            fail[tree[0][i]]=0;
        }
    }
    while(!q.empty())
    {
        int u=q.front();
        q.pop();
        for(int i=0; i<26; i++) {
            if(tree[u][i]) {
                fail[tree[u][i]]=tree[fail[u]][i];
                q.push(tree[u][i]);
            }
            else tree[u][i]=tree[fail[u]][i];
        }
    } 
}
int query(string s)
{
    int ans=0,now=0;
    for(int i=0; i<s.size(); i++) {
        now=tree[now][s[i]-'a'];
    //  cout<<now<<endl;
        for(int j=now; j&&cnt[j]!=-1; j=fail[j]) {
            ans+=cnt[j];
            cnt[j]=-1;
        }
    }
    return ans;
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n;
    cin>>n;
    for(int i=1; i<=n; i++) {
        string s;
        cin>>s;
        insert(s);
    }
    string s;
    bfs();
    cin>>s;
    cout<<query(s);
    return 0;
}