字典樹
字典樹是一種處理前綴的數據結構
略懂數據結構的人,相信看完下面這張圖就差不多理解了
- 的根節點是空的
(相信沒有題目給的所有數據有公共前綴) - 每個節點儲存一個單詞/字母
- 根節點到每個單詞節點的路徑上的所有字母連接而成的字符串就是該節點對應的字符串
- 空間換時間的方法(1秒一般能解決的總字符數量在100000個,空間一般也開800000,適用於查詢比較多的情況)
實現
以小寫字母爲例,講解以下代碼實現
節點
需要一個數組存儲節點信息
外加一個mark標誌,存儲答案
struct Tree{
Tree() { //構造函數
mark = 0;
memset(son, 0, sizeof(son));
}
void clear() {
mark = 0;
memset(son, 0, sizeof(son));
}
int mark; //標記,一般爲詢問的值,視情況而定
int son[26]; //此處只考慮小寫字母,即爲節點存儲的數據
}tree[maxn];
插入
逐層迭代,將沒有的節點插入
在末尾節點的下一個節點存儲該字符串的信息(因爲由這個節點前面構成的字符串是唯一的)
int root, num; //根節點永久爲0
void insert(char* str) {
int position = root; //初始化位置
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a'; //轉化函數,視情況而定
if (!tree[position].son[symbol])//創建新節點
tree[position].son[symbol] = ++num;
position = tree[position].son[symbol];
}
tree[position].mark++; //記錄鏈末尾的數量
}
查詢
逐層遞歸查找
int find(char* str) {
int position = root;
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a';
if (!tree[position].son[symbol]) //找不到,返回false
return 0;
position = tree[position].son[symbol];//迭代尋找
}
return tree[position].mark; //返回相同鏈的數量
}
實際上,我們發現字典樹是種很簡單的數據結構,上面的插入和查詢只是基本操作
當你理解上面的插入和查詢操作後,也可以輕鬆實現其他操作
例題
P2580 於是他錯誤的點名開始了
題意
給出一組名字(好幾個字符串)
進行m次點名,第一次字符串出現輸出“OK”,第二次以上出現輸出“REPEAT”,若該字符串不在給出字符串內輸出“WRONG”
思路
板子題+1
mark記錄該字符串是否出現過
記錄該字符串被訪問次數
P3879 [TJOI2010]閱讀理解
題意
給出n組字符串,每組字符串有1個或多個字符串
有m組詢問,詢問爲一個字符串,返回在哪幾組字符串中出現過
思路
板子題+2
用vector存儲一個字符串在哪幾組字符串中出現過
P5149 會議座位
題意
按順序給出個字符串
第一個字符串最小,第二個第二,以此類推
再次給出n個字符的順序,求逆序對數
思路
板子題+3
字典樹離散化字符串
歸併排序求逆序對數
P3294 [SCOI2016]背單詞
題意
給你n個字符串,不同的排列有不同的代價,代價按照如下方式計算(字符串s的位置爲x)
1.排在s後面的字符串有s的後綴,則代價爲n^2
2.排在s前面的字符串有s的後綴,且沒有排在s後面的s的後綴,則代價爲x-y(y爲最後一個與s不相等的後綴的位置);
3.s沒有後綴,則代價爲x
求最小代價和
思路
顯然第一種代價是不可能被用到的
只要按後綴逐個取字符就能避免第一種情況
處理後綴比較麻煩,我們將字符串翻轉,就可以用字典樹存儲前綴
按前綴逐個取字符串
容易證明先取後綴少的字符串,最終代價少
先重構樹,只保留存在的字符串
先處理各個子樹的節點數
再根據子樹權重,優先對權重小的子樹搜索即可
//具體過程,類似樹鏈剖分
例題代碼
P2580 於是他錯誤的點名開始了
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
#pragma warning (disable:4996)
const int maxn = 1000005;
struct Tree{
Tree() {
mark = 0;
memset(son, 0, sizeof(son));
}
void clear() {
mark = 0;
memset(son, 0, sizeof(son));
}
int mark; //標記
int son[26]; //此處只考慮小寫字母
}tree[maxn];
int root, num; //根節點永久爲0
void insert(char* str) {
int position = root; //初始化位置
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a'; //轉化函數,視情況而定
if (!tree[position].son[symbol])//創建新節點
tree[position].son[symbol] = ++num;
position = tree[position].son[symbol];
}
tree[position].mark = 1; //記錄鏈末尾的數量
}
int find(char* str) {
int position = root;
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a';
if (!tree[position].son[symbol]) //找不到,返回false
return 0;
position = tree[position].son[symbol];//迭代尋找
}
return tree[position].mark++; //返回相同鏈的數量
}
char s[65];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m; cin >> n;
for (int i = 1; i <= n; i++) {
cin >> s;
insert(s);
}
cin >> m;
while (m--) {
cin >> s;
int res = find(s);
if (res == 0)cout << "WRONG\n";
else if (res == 1)cout << "OK\n";
else cout << "REPEAT\n";
}
}
P3879 [TJOI2010]閱讀理解
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
#pragma warning (disable:4996)
const int maxn = 1000005;
struct Tree{
Tree() {
memset(son, 0, sizeof(son));
}
void clear() {
id.clear();
memset(son, 0, sizeof(son));
}
vector<int> id;
int son[26]; //此處只考慮小寫字母
}tree[maxn];
int root, num; //根節點永久爲0
void insert(char* str, int x) {
int position = root; //初始化位置
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a'; //轉化函數,視情況而定
if (!tree[position].son[symbol])//創建新節點
tree[position].son[symbol] = ++num;
position = tree[position].son[symbol];
}
if (tree[position].id.size() == 0 ||
x != tree[position].id[tree[position].id.size() - 1])
tree[position].id.push_back(x);
}
void find(char* str) {
int position = root;
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a';
if (!tree[position].son[symbol]) //找不到,返回false
{
printf("\n");
return;
}
position = tree[position].son[symbol];//迭代尋找
}
for (int i = 0; i < tree[position].id.size(); i++)
if (i)printf(" %d", tree[position].id[i]);
else printf("%d", tree[position].id[i]);
printf("\n");
}
char s[30];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m; cin >> n;
for (int i = 1; i <= n; i++) {
cin >> m;
while (m--) {
cin >> s;
insert(s, i);
}
}
cin >> m;
while (m--) {
cin >> s;
find(s);
}
}
P5149 會議座位
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
#pragma warning (disable:4996)
typedef long long LL;
const int maxn = 100005;
struct Tree{
Tree() {
mark = 0;
memset(son, 0, sizeof(son));
}
void clear() {
mark = 0;
memset(son, 0, sizeof(son));
}
int mark; //標記
int son[52]; //此處只考慮小寫字母
}tree[maxn << 2];
int root, num; //根節點永久爲0
void insert(char* str, int x) {
int position = root; //初始化位置
for (int i = 0; str[i]; i++) {
int symbol;
if (str[i] >= 'A' && str[i] <= 'Z')symbol = str[i] - 'A';
else symbol = str[i] - 'a' + 26;
if (!tree[position].son[symbol])//創建新節點
tree[position].son[symbol] = ++num;
position = tree[position].son[symbol];
}
tree[position].mark = x; //記錄鏈末尾的數量
}
int find(char* str) {
int position = root;
for (int i = 0; str[i]; i++) {
int symbol;
if (str[i] >= 'A' && str[i] <= 'Z')symbol = str[i] - 'A';
else symbol = str[i] - 'a' + 26;
if (!tree[position].son[symbol]) //找不到,返回false
return 0;
position = tree[position].son[symbol];//迭代尋找
}
return tree[position].mark; //返回相同鏈的數量
}
char s[20];
int a[maxn];
LL ans = 0;
void merge(int left, int right, int stdl, int stdr) {
if (left == stdr) return;
int i = left, j = stdl, k = 0;
int* t = new int[stdr - left + 1];
while (i <= right && j <= stdr) {
if (a[i] < a[j]) t[k++] = a[i++];
else {
t[k++] = a[j++];
ans = ans + right - i + 1;
}
}
while (i <= right)t[k++] = a[i++];
while (j <= stdr)t[k++] = a[j++];
for (int i = left; i <= stdr; i++)a[i] = t[i - left];
delete []t;
}
void merge_sort(int left, int right) {
if (left == right) return;
int mid = (left + right) >> 1;
merge_sort(left, mid);
merge_sort(mid + 1, right);
merge(left, mid, mid + 1, right);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n; cin >> n;
for (int i = 1; i <= n; i++) {
cin >> s;
insert(s, i);
}
for (int i = 1; i <= n; i++) {
cin >> s;
a[i] = find(s);
}
merge_sort(1, n);
printf("%lld\n", ans);
}
P3294 [SCOI2016]背單詞
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
#pragma warning (disable:4996)
typedef long long LL;
const int maxn = 100005;
struct Tree{
Tree() {
mark = 0;
memset(son, 0, sizeof(son));
}
void clear() {
mark = 0;
memset(son, 0, sizeof(son));
}
bool mark; //標記
int son[26]; //此處只考慮小寫字母
}tree[maxn << 3];
int root, num; //根節點永久爲0
void insert(char* str) {
int position = root; //初始化位置
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a'; //轉化函數,視情況而定
if (!tree[position].son[symbol])//創建新節點
tree[position].son[symbol] = ++num;
position = tree[position].son[symbol];
}
tree[position].mark = true; //記錄鏈末尾的數量
}
int find(char* str) {
int position = root;
for (int i = 0; str[i]; i++) {
int symbol = str[i] - 'a';
if (!tree[position].son[symbol]) //找不到,返回false
return 0;
position = tree[position].son[symbol];//迭代尋找
}
return tree[position].mark; //返回相同鏈的數量
}
char s[maxn * 6];
vector<int> E[maxn]; int cnt;
void struct_dfs(int root, int f) {
for (int i = 0; i < 26; i++) {
if (!tree[root].son[i]) continue;
if (!tree[tree[root].son[i]].mark)
struct_dfs(tree[root].son[i], f);
else {
E[f].push_back(++cnt);
struct_dfs(tree[root].son[i], cnt);
}
}
}
int n_size[maxn];
bool cmp(int x, int y) {
return n_size[x] < n_size[y];
}
void dfs1(int now, int f) {
n_size[now] = 1;
for (int i = 0; i < E[now].size(); i++) {
dfs1(E[now][i], now);
n_size[now] += n_size[E[now][i]];
}
sort(E[now].begin(), E[now].end(), cmp);
}
int dfn, d[maxn];
LL ans = 0;
void dfs2(int now, int f) {
d[now] = ++dfn;
ans = ans + d[now] - d[f];
for (int i = 0; i < E[now].size(); i++)
dfs2(E[now][i], now);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n; cin >> n;
while (n--) {
cin >> s;
reverse(s, s + strlen(s));
insert(s);
}
struct_dfs(0, 0);
dfs1(0, -1);
dfs2(0, 0);
cout << ans << '\n';
}