數據結構-字符串-AC自動機

數據結構-字符串-AC自動機

作用:單文本串多模式串匹配。
前綴知識:trie樹\color{orange}\texttt{trie樹}

AC自動機可以看作是在字典樹上做 KMP,但並不是把 KMP 算法放到樹上來,而是用了一種和 KMP 類似的思想,即在字典樹上匹配文本串的時候如果失配,就跳到 failfail 指針所指的節點,所以學AC自動機沒必要精通 KMP。

拿例題來講:給定 nn 個模式串和 11 個文本串,求有多少個模式串在文本串裏出現過。

那麼要先構造一個字典樹將上述字符串存儲,代碼如下:

class Trie{
public:
	int ch[N][26],mk[N],cnt;
	Trie(){cnt=1;}
	void insert(char*s){
		int n=strlen(s+1),x=1;
		for(int i=1;i<=n;i++){
			int c=s[i]-'a'+1;
			if(!ch[x][c]) ch[x][c]=++cnt;
			x=ch[x][c];
		}
		mk[x]++;//以該節點爲結尾的模式串數
	}
};

比如有這些模式串:
kony\texttt{kony}
wen\texttt{wen}
emm\texttt{emm}
kib\texttt{kib}

那麼構造成的字典樹會長這樣:
演示文稿1.jpg
AC自動機就是在字典樹的基礎上,對於每個節點 xx,增加一個指針 fail[x]fail[x],如上文所述,用來在失配時跳轉指針,這樣如果匹配失敗就不需要回溯了。

如上圖,根節點的子節點,即第一層的的節點 xx,應該有 fail[x]=1fail[x]=1,即如果在第一個字符失配,就從頭開始再找(如下圖,黃色箭頭表示 fail[]fail[])。

acam2.jpg
如果節點 xx 沒有子節點 ch[x][c]ch[x][c],那麼 ch[x][c]=ch[fail[x]][c]ch[x][c]=ch[fail[x]][c],相當於由於沒有該字符子節點而失配(末尾符失配),自動跳轉 failfail(如下圖,紫色箭頭表示 ch[][]ch[][])。

acam3.jpg

爲了簡化問題,後文圖中不加末尾符失配紫色指針。

bfsbfs 序遍歷字典樹。如果節點 xx 有子節點 ch[x][c]ch[x][c],那麼 fail[ch[x][c]]=ch[fail[x]][c]fail[ch[x][c]]=ch[fail[x]][c],即如果失配就跳到最相鄰已經遍歷過的字符爲 cc 的節點中。如果目前還沒有發現字符爲 cc 的節點,就令 fail[ch[x][c]]=1fail[ch[x][c]]=1。如下圖:

acam4.jpg
這是構造AC自動機 fail[]fail[] 數組的代碼:

void build(){
	for(int i=1;i<=26;i++) ch[0][i]=1;//把1節點的子節點同樣操作
	queue<int> q;while(q.size()) q.pop();q.push(1);//☆
	while(q.size()){//按bfs序求fail 指針
		int x=q.front();q.pop();
		for(int i=1;i<=26;i++)
			if(ch[x][i]) fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]);//☆
			else ch[x][i]=ch[fail[x]][i];//☆
	}
}

最後構造好 failfail 指針以後的字典樹就是AC自動機,長這樣:

acam5.jpg
然後就是重點了——應用 failfail 查找有幾個模式串在文本串中出現。代碼如下:

int fapp(char*s){
	int n=strlen(s+1),x=1,res=0;
	for(int i=1;i<=n;i++){
		x=ch[x][s[i]-'a'+1];
		for(int j=x;j&&mk[j]!=-1;j=fail[j])//mk置爲-1防止重複計算
			res+=mk[j],mk[j]=-1;
	}
	return res;
}

這時候你會很驚駭:這哪是失配跳轉啊,這分明就是指針亂飛!其實仔細想的話,其實是指針在整個AC自動機間穿梭(說了等於沒說),由於之前的紫色箭頭 ch[][]ch[][] 指針,指針表面上順着字典樹走的同時,也在自動末尾符失配跳轉,即單前字典樹節點如果沒有某個字符子節點,就會自動跳到有該字符的節點上或者根節點。而後面那句 failfail 指針跳轉的 for\texttt{for} 循環,就求出了單前節點到根節點所連成的字符串的後綴的出現次數。

然後如上一波猛如犇的操作以後,答案——模式串在文本串中出現的次數就出現了。如果你懂了,蒟蒻就放代碼了:

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
class Trie{
public:
	int ch[N][26],mk[N],cnt;
	Trie(){cnt=1;}
	void insert(char*s){
		int n=strlen(s+1),x=1;
		for(int i=1;i<=n;i++){
			int c=s[i]-'a'+1;
			if(!ch[x][c]) ch[x][c]=++cnt;
			x=ch[x][c];
		}
		mk[x]++;
	}
};
class Acam:public Trie{//Class 繼承
public:
	int fail[N];
	void build(){
		for(int i=1;i<=26;i++) ch[0][i]=1;
		queue<int> q;while(q.size()) q.pop();q.push(1);
		while(q.size()){
			int x=q.front();q.pop();
			for(int i=1;i<=26;i++)
				if(ch[x][i]) fail[ch[x][i]]=ch[fail[x]][i],q.push(ch[x][i]);
				else ch[x][i]=ch[fail[x]][i];
		}
	}
	int fapp(char*s){
		int n=strlen(s+1),x=1,res=0;
		for(int i=1;i<=n;i++){
			x=ch[x][s[i]-'a'+1];
			for(int j=x;j&&mk[j]!=-1;j=fail[j])
				res+=mk[j],mk[j]=-1;
		}
		return res;
	}
}m;
int num;
char s[N];
int main(){
	scanf("%d",&num);
	for(int i=1;i<=num;i++)
		scanf("%s",s+1),m.insert(s);
	m.build();
	scanf("%s",s+1);
	printf("%d\n",m.fapp(s));
	return 0;
}

可是如果字符串一多,字典樹一大,那麼那個重要的語句:

for(int j=x;j&&mk[j]!=-1;j=fail[j])
	res+=mk[j],mk[j]=-1;

反而會造成時間超限,如這道例題:

洛谷P5357 【模板】AC自動機(二次加強版)

如果你直接按上面的代碼改改,會 TLE75分\color{#333377}\texttt{TLE75分}

是時候優化如上穿梭指針語句了,那麼怎麼優化呢?我們發現如果把 fail[]fail[] 看成一些邊,就會構成一個 DAG\texttt{DAG} ,而答案更新又是按照 fail[]fail[] 數組跳指針的,這時我們必須有想到一個算法的直覺:拓撲排序。

因爲不跳 failfail 了,所以就不需要 mk[x]mk[x] 數組標記以 xx 節點爲結尾的字符串數了。因爲要拓撲排序,所以記錄每個字符串編號 ii 的終止節點 en[i]en[i]。所以 insert()\texttt{insert()} 函數長這樣:

void insert(char*s,int x){
	int n=strlen(s+1),p=1;
	for(int i=1;i<=n;i++){
		int c=s[i]-'a'+1;
		if(!ch[p][c]) ch[p][c]=++cnt;
		p=ch[p][c];
	}
	en[x]=p;
}

然後AC自動機的 failfail 構造函數不變,因爲要拓撲求答案,所以對於每個字符,只需要在該字符結尾的最長串加上標記即可。所以 fapp()\texttt{fapp()} 求答案函數要變成這樣:

void fapp(char*s){
	int n=strlen(s+1),p=1;
	for(int i=1;i<=n;i++)
		//mk[p=ch[p][s[i]-'a'+1]]++; 這麼寫對萌新不友好
		p=ch[p][s[i]-'a'+1],mk[p]++;
}

然後按 failfail 指針加反邊:

for(int i=2;i<=m.cnt;i++) g[m.fa[i]].push_back(i);

然後拓撲求答案即可:

void dfs(int x){
	for(auto to:g[x]) dfs(to),m.mk[x]+=m.mk[to];
}

最後對於每個字符串編號 iimk[en[i]]mk[en[i]] 就是該模式字符串在文本串中出現的次數。如果你都懂了,那麼蒟蒻就放代碼了:

#include <bits/stdc++.h>
using namespace std;
const int N=2e5+10,T=2e6+10;
class Trie{
public:
	int cnt,ch[N][30],en[N],mk[N];
	Trie(){cnt=1;}
	void insert(char*s,int x){
		int n=strlen(s+1),p=1;
		for(int i=1;i<=n;i++){
			int c=s[i]-'a'+1;
			if(!ch[p][c]) ch[p][c]=++cnt;
			p=ch[p][c];
		}
		en[x]=p;
	}
};
class Acam:public Trie{
public:
	int fa[N];
	void build(){
		for(int i=1;i<=26;i++) ch[0][i]=1;
		queue<int> q;while(q.size()) q.pop();q.push(1);
		while(q.size()){
			int x=q.front();q.pop();
			for(int i=1;i<=26;i++)
				if(ch[x][i]) fa[ch[x][i]]=ch[fa[x]][i],q.push(ch[x][i]);
				else ch[x][i]=ch[fa[x]][i];
		}	
	}
	void fapp(char*s){
		int n=strlen(s+1),p=1;
		for(int i=1;i<=n;i++)
			mk[p=ch[p][s[i]-'a'+1]]++;
	}
}m;
int n;
char s[T];
vector<int> g[N];
void dfs(int x){
	for(auto to:g[x]) dfs(to),m.mk[x]+=m.mk[to];
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%s",s+1),m.insert(s,i);
	m.build(),scanf("%s",s+1),m.fapp(s);
	for(int i=2;i<=m.cnt;i++) g[m.fa[i]].push_back(i);
	dfs(1);
	for(int i=1;i<=n;i++) printf("%d\n",m.mk[m.en[i]]);
	return 0;
}

學字符串數據結構之路(\texttt{★} 表示單前位置):

hash-kmp-manacher-exkmp-trie-acam-sa-sam-pam\color{#cccccc}\texttt{hash}\color{#aaaaff}\texttt{-}\color{#8888ff}\texttt{kmp}\color{#88cccc}\texttt{-}\color{#88ff88}\texttt{manacher}\color{#cccc88}\texttt{-}\color{#dddd44}\texttt{exkmp}\color{#eeaa44}\texttt{-}\color{#ffaa00}\texttt{trie}\color{#ff8800}\texttt{-}\color{#ee2200}\texttt{acam}\color{#000000}\texttt{★}\color{#ee0088}\texttt{-}\color{#cc00ff}\texttt{sa}\color{#660077}\texttt{-}\color{#555555}\texttt{sam}\color{#272727}\texttt{-}\color{#000000}\texttt{pam}

蒟蒻前途不可斗量,蒟蒻必須努力,不停下奮鬥的腳步。

祝大家學習愉快!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章