AC自动机

· · 个人记录

AC自动机适用于在一个长串中查询多个小串的匹配问题

P3808 【模板】AC自动机(简单版)

在一个长串中查询多个小串一共出现了几次

名称解释

关键语句

/* build() */
if (trie[cur][i])   fail[trie[cur][i]] = trie[fail[cur]][i], que[++rear] = trie[cur][i];
else    trie[cur][i] = trie[fail[cur]][i];
...
/* query() */
for (register int id = np; id && ~cnt[id]; id = fail[id])
    ans += cnt[id], cnt[id] = -1;

注意

与SAM不同,AC自动机的根节点编号为0,这样更方便些(如build时少出些锅)

AC自动机的真实节点个数为 ttot + 1,还有一个0节点

```cpp int trie[N][26], fail[N], ttot, cnt[N]; inline void insert(char *s) { int len = strlen(s + 1); int np = 0; for (register int i = 1; i <= len; ++i) { int ch = s[i] - 'a'; if (!trie[np][ch]) trie[np][ch] = ++ttot; np = trie[np][ch]; } cnt[np]++; } int que[N], front, rear; inline void build() { for (register int i = 0; i < 26; ++i) if (trie[0][i]) que[++rear] = trie[0][i]; while (front < rear) { int cur = que[++front]; for (register int i = 0; i < 26; ++i) { if (trie[cur][i]) fail[trie[cur][i]] = trie[fail[cur]][i], que[++rear] = trie[cur][i]; else trie[cur][i] = trie[fail[cur]][i]; } } } int ans; inline void query(char *s) { int len = strlen(s + 1); int np = 0; for (register int i = 1; i <= len; ++i) { int ch = s[i] - 'a'; np = trie[np][ch]; for (register int id = np; id && ~cnt[id]; id = fail[id]) ans += cnt[id], cnt[id] = -1; } } ``` ## [P5357 【模板】AC自动机(二次加强版)](https://www.luogu.com.cn/problem/P5357) 在一个长串中查询多个小串**各自**出现了几次 --- 自己想出来一种极其美妙的方法:与上面的题类似,只不过这次不打 -1 标记了,这样就能统计到所有模式串的出现次数了。 ```cpp inline void query(char *s) { register int np = 0, len = strlen(s + 1); for (register int i = 1; i <= len; ++i) { register int ch = s[i] - 'a'; np = trie[np][ch]; for (register int j = np; j; j = fail[j]) if (fg[j]) cnt[j]++; } } ``` 最后查cnt即可。 然而 TLE 了。 --- 发现 fail 指针和 SAM 的fa指针(suffix links)类似,最终 fail 指针也会形成一个类似 Parents Tree 的东西(每个点都只有一个父亲,且无环),不妨称之为 fail tree(我自己起的名)。 每次扫到大串的一个字符,就会执行 从 np 到根的路径 这么多次的for,因此TLE了。即TLE的原因在下面: ```cpp for (register int j = np; j; j = fail[j]) if (fg[j]) cnt[j]++; ``` 那么我们干脆执行**树上差分**就行了。 ```cpp inline void topu() { front = rear = 0; for (register int i = 1; i <= ttot; ++i) if (!d[i]) que[++rear] = i; while (front < rear) { register int cur = que[++front]; tag[fail[cur]] += tag[cur]; if ((--d[fail[cur]]) == 0) que[++rear] = fail[cur]; } } ``` ## [P2444 [POI2000]病毒](https://www.luogu.com.cn/problem/P2444) 给一堆小串,问是否存在一个长串的长度为无限长,且长串**不含**任何小串。 --- 题目的意思就是要我们拿一个长串在Trie上匹配,**永远跑不到非法节点**,即存在环。 要注意AC自动机是**有限状态自动机**,即它能够识别一些字符串,即那些小串。它的**正式的边为Trie[][]引出的边,包括原Trie上的和后加的**。拿个长串在这些边上面跑,跑到终止节点就被识别。并不是所有节点都是接受节点。 因此查找自动机里面是否有环。 注意,非法节点不止是终止节点,还有失配节点是终止节点的点(如:ab,ababcbd,Trie上的ab节点和abab节点都非法) $Code$ : [my record](https://www.luogu.com.cn/record/31220091) 题解:[题解 P2444 【[POI2000]病毒】](https://www.luogu.com.cn/blog/wangwangwangwangwang/solution-p2444) ## [P4052 [JSOI2007]文本生成器](https://www.luogu.com.cn/problem/P4052) 给定n个小串,问有多少个长为m的大串,其中大串包含至少一个小串。 $n <= 60, \sum{len} <= 6000, m <= 100, P = 10007

AC自动机DP

仍然考虑像上一道题那样模拟AC自动机的运转过程。那么那个大串应该在匹配过程中到达过(或到达点的fail中)至少一个终止节点。

感觉一到达终止节点就停止,似乎接下来还有一段长度,不太好做。于是可以考虑容斥,转而开始求没有被任意终止节点接受的方案数。是不是感觉和上一道题有点像?只不过变成计数问题了。

设定 f[i][p] 表示大串已经匹配到了第 i 个字符,目前在 p 点(仍未被接受)的方案数。转移显而易见。

最后答案就是总方案 减掉 所有非接受节点的 f 。

可以滚动数组优化。如果 \sum len 小一点, m 大一点,我们还可以用矩乘优化。

int f[2][NN], ans;
inline void dp() {
    f[0][0] = 1;
    bool type = 0;
    for (register int i = 1; i <= m; ++i) {
        type ^= 1;
        memset(f[type], 0, sizeof(f[type]));
        for (register int p = 0; p <= ttot; ++p) {
            if (fg[p])  continue;
            for (register int c = 0; c < 26; ++c) {
                f[type][trie[p][c]] = (f[type][trie[p][c]] + f[type ^ 1][p]) % P;
            }
        }
    }
    for (register int p = 0; p <= ttot; ++p) {
        if (fg[p])  continue;
        ans += f[type][p];
    }
}

一些练习题

板子题

Trie树题,不过是用到了AC自动机里面的类似fail指针的东西以及fail树,又结合了树状数组(Trie链与fail子树求交集),有一定难度。

模板

int trie[N][26], fail[N], ttot, cnt[N];
inline void insert(char *s) {
    int len = strlen(s + 1);
    int np = 0;
    for (register int i = 1; i <= len; ++i) {
        int ch = s[i] - 'a';
        if (!trie[np][ch]) trie[np][ch] = ++ttot;
        np = trie[np][ch];
    }
    cnt[np]++;
}

int que[N], front, rear;
inline void build() {
    for (register int i = 0; i < 26; ++i)   if (trie[0][i]) que[++rear] = trie[0][i];
    while (front < rear) {
        int cur = que[++front];
        for (register int i = 0; i < 26; ++i) {
            if (trie[cur][i])   fail[trie[cur][i]] = trie[fail[cur]][i], que[++rear] = trie[cur][i];
            else    trie[cur][i] = trie[fail[cur]][i];
        }
    }
}

模板(二进制分组,Trie树的合并,AC自动机的构建,结构体封装)

//CF710F String Set Queries
struct ACAM {
    int trie[N][26], fail[N], path[N][26], ttot, rt[N], stk[N], num[N], sum[N], stop;
    inline void init() {
        memset(trie, 0, sizeof(trie));
        memset(path, 0, sizeof(path));
        memset(fail, 0, sizeof(fail));
        memset(rt, 0, sizeof(rt));
        memset(num, 0, sizeof(num));
        memset(sum, 0, sizeof(sum));
        ttot = 0, stop = 0;
    }
    inline void ins() {
        rt[++stop] = ++ttot; stk[stop] = 1;
        int n = str.size();
        int p = rt[stop];
        for (register int i = 0; i < n; ++i) {
            int c = str[i] - 'a';
            if (!trie[p][c])    trie[p][c] = ++ttot;
            p = trie[p][c];
        }
        //++num[p];
        num[p] = 1;
    }
    int que[N], front, rear;
    inline void build(int Rt) {
        front = rear = 0;
        for (register int i = 0; i < 26; ++i)
            if (trie[Rt][i])
                que[++rear] = trie[Rt][i], path[Rt][i] = trie[Rt][i], fail[trie[Rt][i]] = Rt;
            else    path[Rt][i] = Rt;
        while (front < rear) {
            int p = que[++front];
            for (register int c = 0; c < 26; ++c) {
                if (trie[p][c]) {
                    path[p][c] = trie[p][c], fail[path[p][c]] = path[fail[p]][c];
                    que[++rear] = path[p][c];
                } else {
                    path[p][c] = path[fail[p]][c];
                }
            }
            sum[p] = num[p] + sum[fail[p]];
        }
    }
    int merge(int lcur, int rcur) {
        if (!lcur || !rcur) return lcur ^ rcur;
        num[lcur] += num[rcur];
        for (register int c = 0; c < 26; ++c)
            trie[lcur][c] = merge(trie[lcur][c], trie[rcur][c]);
        return lcur;
    }
    inline void add() {
        ins();
        while (stk[stop] == stk[stop - 1])
            rt[stop - 1] = merge(rt[stop - 1], rt[stop]), stk[stop - 1] += stk[stop], stop--;
        build(rt[stop]);
    }
    inline int query() {
        int len = str.size(), res = 0;
        for (register int t = stop; t; --t) {
            int p = rt[t];
            for (register int i = 0; i < len; ++i) {
                int c = str[i] - 'a';
                p = path[p][c];
                res += sum[p];
            }
        }
        return res;
    }
}A, B;