AC自动机
jiazhaopeng · · 个人记录
AC自动机适用于在一个长串中查询多个小串的匹配问题。
P3808 【模板】AC自动机(简单版)
在一个长串中查询多个小串里有多少个出现过(每个小串至多统计一次)
名称解释
关键语句
/* build(): statement executed for every letter i of the node cur taken from the BFS queue */
// If cur has a real child on i: that child's fail link is the node reached from
// fail[cur] on i, and the child is appended to the queue.
// Otherwise the missing edge is filled in with the node reached from fail[cur] on i,
// so later lookups of trie[cur][i] never need to walk fail links.
if (trie[cur][i]) fail[trie[cur][i]] = trie[fail[cur]][i], que[++rear] = trie[cur][i];
else trie[cur][i] = trie[fail[cur]][i];
...
/* query(): statement executed for the current node np */
// Walk the fail chain starting at np. The walk stops at node 0 (id == 0) or at a
// node whose cnt is -1 (~(-1) == 0), i.e. one that was already harvested.
// Each visited node contributes cnt[id] to ans once and is then marked with -1.
for (register int id = np; id && ~cnt[id]; id = fail[id])
ans += cnt[id], cnt[id] = -1;
注意
与SAM不同,AC自动机的根节点编号为0,这样更方便些(如build时少出些锅)
AC自动机的真实节点个数为 $ttot + 1$(节点编号为 $0 \sim ttot$,其中 0 是根节点),遍历所有节点时要从 0 循环到 $ttot$。
AC自动机DP
仍然考虑像上一道题那样模拟AC自动机的运转过程。那么那个大串应该在匹配过程中到达过(或到达点的fail中)至少一个终止节点。
感觉一到达终止节点就停止,似乎接下来还有一段长度,不太好做。于是可以考虑容斥,转而开始求没有被任意终止节点接受的方案数。是不是感觉和上一道题有点像?只不过变成计数问题了。
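设定 $f_{i,p}$ 表示已经填了 $i$ 个字符、当前位于AC自动机的节点 $p$,且途中没有经过任何终止节点的方案数。转移时枚举下一个字符 $c$:若 $p$ 不是终止节点,则 $f_{i+1,\,trie_{p,c}} \leftarrow f_{i+1,\,trie_{p,c}} + f_{i,p}$。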
最后答案就是总方案 减掉 所有非接受节点的 f 。
可以滚动数组优化。如果
// f[layer][p]: number of ways (mod P) to be at automaton node p after the
// current number of steps, counting only walks that never step out of a node
// with fg[p] set. Only two layers are kept (rolling array).
// ans: sum (mod P) of the final layer over all nodes with fg[p] == 0.
int f[2][NN], ans;

// Runs m steps of the DP over the automaton transitions trie[p][c] and
// accumulates the result into the global ans.
// NOTE(review): fg[] is presumably the "accepting node" flag described in the
// surrounding notes (terminal nodes, including those reachable via fail links);
// it is defined outside this block — confirm.
inline void dp() {
    f[0][0] = 1;
    bool type = 0;
    for (int i = 1; i <= m; ++i) {
        type ^= 1;
        // Clear the layer that is about to be written.
        memset(f[type], 0, sizeof(f[type]));
        for (int p = 0; p <= ttot; ++p) {
            if (fg[p]) continue;  // walks are not extended out of flagged nodes
            for (int c = 0; c < 26; ++c) {
                f[type][trie[p][c]] = (f[type][trie[p][c]] + f[type ^ 1][p]) % P;
            }
        }
    }
    for (int p = 0; p <= ttot; ++p) {
        if (fg[p]) continue;
        // Reduce while accumulating: the summands are residues mod P, so the
        // plain sum has no meaning beyond its residue and could overflow int.
        ans = (ans + f[type][p]) % P;
    }
}
一些练习题
- P3966 [TJOI2013]单词
板子题
- P2414 [NOI2011]阿狸的打字机
Trie树题,不过是用到了AC自动机里面的类似fail指针的东西以及fail树,又结合了树状数组(Trie链与fail子树求交集),有一定难度。
模板
// trie[u][c]: child of node u on letter c (0 = absent; node 0 is the root).
// fail[u]:    failure link of node u.
// ttot:       index of the last allocated node (nodes are 0..ttot).
// cnt[u]:     number of inserted strings that end exactly at node u.
int trie[N][26], fail[N], ttot, cnt[N];

// Inserts a pattern into the trie.
// s is read starting at s[1] (1-indexed); s[0] is ignored. Characters are
// mapped with `- 'a'`, so the code assumes lowercase letters 'a'..'z'.
inline void insert(char *s) {
    const int len = static_cast<int>(strlen(s + 1));
    int np = 0;  // current node, starting from the root
    for (int i = 1; i <= len; ++i) {
        const int ch = s[i] - 'a';
        if (!trie[np][ch]) trie[np][ch] = ++ttot;  // allocate a missing child
        np = trie[np][ch];
    }
    ++cnt[np];
}
// BFS queue over trie nodes; entries live in que[front+1 .. rear].
int que[N], front, rear;

// Computes fail links for every trie node and fills in every missing edge
// trie[u][c] with the node reached through the fail link, so that afterwards
// trie[u][c] is a total transition function.
inline void build() {
    // Start from an empty queue so the function does not depend on the
    // cursors left behind by an earlier call.
    front = rear = 0;
    // Children of the root keep fail == 0 (the arrays are zero until written).
    for (int i = 0; i < 26; ++i) {
        if (trie[0][i]) que[++rear] = trie[0][i];
    }
    while (front < rear) {
        const int cur = que[++front];
        for (int i = 0; i < 26; ++i) {
            const int child = trie[cur][i];
            if (child) {
                // fail[cur] was dequeued earlier (it is shallower), so its
                // transition on i is already final.
                fail[child] = trie[fail[cur]][i];
                que[++rear] = child;
            } else {
                trie[cur][i] = trie[fail[cur]][i];
            }
        }
    }
}
模板(二进制分组,Trie树的合并,AC自动机的构建,结构体封装)
//CF710F String Set Queries
// A stack of tries sharing one node pool, each with its own automaton.
// add() pushes a one-string trie and merges the two topmost tries while they
// hold the same number of strings; query() sums matches over all tries.
//
// The string to insert / search is read from `str`, a variable declared
// outside this struct (used via str.size() and str[i]; presumably a
// std::string of lowercase letters — confirm against its declaration).
struct ACAM {
    // trie[u][c]: child of u on letter c in the raw trie (0 = absent).
    // path[u][c]: automaton transition out of u on c, filled by build().
    // fail[u]:    failure link of u (set by build()).
    // num[u]:     end-of-string marker of u (summed when tries are merged).
    // sum[u]:     num accumulated along u's fail chain (set by build()).
    // ttot:       index of the last allocated node; node 0 means "no node".
    // rt[1..stop], stk[1..stop]: root node and string count of each trie.
    int trie[N][26], fail[N], path[N][26], ttot, rt[N], stk[N], num[N], sum[N], stop;

    // Resets the structure to the empty state.
    inline void init() {
        memset(trie, 0, sizeof(trie));
        memset(path, 0, sizeof(path));
        memset(fail, 0, sizeof(fail));
        memset(rt, 0, sizeof(rt));
        memset(num, 0, sizeof(num));
        memset(sum, 0, sizeof(sum));
        ttot = 0, stop = 0;
    }

    // Pushes a new trie containing only `str` onto the stack.
    inline void ins() {
        rt[++stop] = ++ttot;  // fresh root for the new trie
        stk[stop] = 1;
        const int n = static_cast<int>(str.size());
        int p = rt[stop];
        for (int i = 0; i < n; ++i) {
            const int c = str[i] - 'a';
            if (!trie[p][c]) trie[p][c] = ++ttot;
            p = trie[p][c];
        }
        // The end node is marked with 1 (assigned, not incremented).
        num[p] = 1;
    }

    // BFS queue used by build(); entries live in que[front+1 .. rear].
    int que[N], front, rear;

    // (Re)builds fail, path and sum for the trie rooted at Rt. The raw trie
    // edges are left untouched so the trie can still be merged later.
    inline void build(int Rt) {
        front = rear = 0;
        for (int i = 0; i < 26; ++i) {
            const int child = trie[Rt][i];
            if (child) {
                que[++rear] = child;
                path[Rt][i] = child;
                fail[child] = Rt;
            } else {
                path[Rt][i] = Rt;  // missing edge at the root loops back to it
            }
        }
        while (front < rear) {
            const int p = que[++front];
            for (int c = 0; c < 26; ++c) {
                if (trie[p][c]) {
                    path[p][c] = trie[p][c];
                    fail[path[p][c]] = path[fail[p]][c];
                    que[++rear] = path[p][c];
                } else {
                    path[p][c] = path[fail[p]][c];
                }
            }
            // fail[p] is shallower than p, so its sum is already up to date.
            sum[p] = num[p] + sum[fail[p]];
        }
    }

    // Merges the trie rooted at rcur into the one rooted at lcur and returns
    // the root of the result. Recurses once per shared node.
    int merge(int lcur, int rcur) {
        // At least one side is 0 here, so XOR yields the other one (or 0).
        if (!lcur || !rcur) return lcur ^ rcur;
        num[lcur] += num[rcur];
        for (int c = 0; c < 26; ++c) {
            trie[lcur][c] = merge(trie[lcur][c], trie[rcur][c]);
        }
        return lcur;
    }

    // Adds `str` to the set: push a one-string trie, merge equal-sized
    // neighbours on top of the stack, then rebuild the automaton of the top.
    inline void add() {
        ins();
        // `stop > 1` keeps the comparison from reading stk[0], which is never
        // written and is not cleared by init().
        while (stop > 1 && stk[stop] == stk[stop - 1]) {
            rt[stop - 1] = merge(rt[stop - 1], rt[stop]);
            stk[stop - 1] += stk[stop];
            --stop;
        }
        build(rt[stop]);
    }

    // Runs `str` through every automaton on the stack and returns the total
    // of sum[] over all visited states.
    inline int query() {
        const int len = static_cast<int>(str.size());
        int res = 0;
        for (int t = stop; t; --t) {
            int p = rt[t];
            for (int i = 0; i < len; ++i) {
                const int c = str[i] - 'a';
                p = path[p][c];
                res += sum[p];
            }
        }
        return res;
    }
}A, B;