P3808
【模板】AC自动机(简单版)
AC 自动机似乎就是一个 Trie 上的 KMP,用来处理多模式串的匹配问题。
我们先用模式串构建 Trie 树,然后求出每个节点的
那么构建 Trie 树的方法需要 BFS,先处理好浅层的
先给出这一段代码:
void AC_pre() {
queue<ll> q;
Next[0]=0;//初始化最开始的几个节点
for(ll c=0;c<26;c++) {
ll u=trie[0][c];
if(u) {
Next[u]=0;q.push(u);lst[u]=0;
//BFS预处理,lst 的作用在后面说明
}
}
while(!q.empty()) {
ll h=q.front();q.pop();
for(ll c=0;c<26;c++) {
ll u=trie[h][c];
if(!u) {
trie[h][c]=trie[Next[h]][c];//一个补边的优化,在后面解释
continue;
}
q.push(u);
ll v=Next[h];
while(v&&!trie[v][c]) v=Next[v];//寻找 h 的某个子节点 u 的 Next
Next[u]=trie[v][c];
lst[u]=cnt[Next[u]]?Next[u]:lst[Next[u]];//lst 的作用之后再说
}
}
}
我们这里给出了一个
至于为什么需要添加这个数组,在某蓝书中差不多是这样解释的:我们单纯的找到一个词汇,显然很有可能不会考虑以改词的后缀为一个完整的词汇的统计,所以只有加上这个才能完全统计。我大概看了好一会才明白。据蓝书的解释,
然后呢,这个
然后解释一下那个补边的优化。这个加边的意思其实就是在失配的时候跳
AC 自动机的主要过程就是建 Trie,预处理
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define ll long long
using namespace std;
const ll N=1e6;
ll n,ans,tot;
ll cnt[N+5],trie[N+5][26],Next[N+5],lst[N+5];
bool vis[N+5];
char tmp[N+5],s[N+5];
void ins(char *str) {
ll len=strlen(str),p=0;
for(ll k=0;k<len;k++) {
ll ch=str[k]-'a';
if(!trie[p][ch]) trie[p][ch]=++tot;
p=trie[p][ch];
}
cnt[p]++;
}
void count(ll j) {
if(j&&!vis[j]) {
ans+=cnt[j];vis[j]=1;
count(lst[j]);
}
}
void AC_pre() {
queue<ll> q;
Next[0]=0;
for(ll c=0;c<26;c++) {
ll u=trie[0][c];
if(u) {
Next[u]=0;q.push(u);lst[u]=0;
}
}
while(!q.empty()) {
ll h=q.front();q.pop();
for(ll c=0;c<26;c++) {
ll u=trie[h][c];
if(!u) {
trie[h][c]=trie[Next[h]][c];
continue;
}
q.push(u);
ll v=Next[h];
while(v&&!trie[v][c]) v=Next[v];
Next[u]=trie[v][c];
lst[u]=cnt[Next[u]]?Next[u]:lst[Next[u]];
}
}
}
void AC(char *str) {
ll len=strlen(str);
ll j=0;
for(ll i=0;i<len;i++) {
ll c=str[i]-'a';
// printf("In function AC:");
// putchar(c+'a');putchar('\n');
j=trie[j][c];
// printf("j=%lld\n",j);
if(cnt[j]) count(j);
else {
if(lst[j]) count(lst[j]);
}
}
}
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}
void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=x%10+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}
int main() {
n=read();
for(ll i=1;i<=n;i++) {
scanf("%s",tmp);
ins(tmp);
}
AC_pre();
scanf("%s",s);
AC(s);
write(ans);
return 0;
}