hdu3341
題意
給m個模式串和一個母串(字符串都是由ATCG組成),求這個母串重組後最多包含多少個模式串,可以重疊。
思路
看題目數據範圍很小,首先暴力的思想,把母串的所有可能的排列方式都求一遍,取最大值。建立模式串的AC自動機,dp[i][x1][x2][x3][x4]表示i狀態下有x1個A,x2個T,x3個C,x4個G。狀態轉移:dp[j][相應字母加一]=max(dp[j][相應字母加一],dp[i][x1][x2][x3][x4]+tag[trie[i][j]]),然後對於每個狀態取字符個數和爲n的最大值就可以了;但是!40404040500超內存,所以不行,考慮狀態壓縮:字符串總長40,當每種字符個數相等時,11 * 11 * 11 * 11表示的狀態是最大的,也就是長度40的字符串表示的狀態最多11^4,所以可以令MaxA、MaxC、MaxG、MaxT分別表示四種字符出現的個數,那麼T字符的權值爲1,G字符的權值爲(MaxT + 1),C字符的權值爲(MaxG + 1) * (MaxT + 1),A字符的權值爲(MaxC + 1) * (MaxG + 1) * (MaxT + 1),進行進制壓縮之後總的狀態數不會超過11^4,可以用DP[i][j]表示在trie的i號結點時ACGT四個字符個數的壓縮狀態爲j時的字符串包含模式串的最多數目,然後就是進行O(4500114)的狀態轉移了。
參考博客
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 500 + 10, M = 4;
int trie[MAXN][M];
int tag[MAXN];
int fail[MAXN];
int L = 0, root;
map<char, int> mp;
int newnode() {
for (int i = 0; i < M; i++) trie[L][i] = 0;
tag[L++] = 0;
return L - 1;
}
void init() {
L = 0;
root = newnode();
}
void insertWords(char *s)
{
int now = root, SIZE = strlen(s);
for (int i = 0; i < SIZE; i++) {
int next = mp[s[i]];
if(!trie[now][next])
trie[now][next] = newnode();
now = trie[now][next];
}
tag[now]++;
}
void getFail()//一個節點的fail指針是指向 這個節點表示的字符串的最長後綴串的最後一個節點
{
queue<int> q;
for(int i = 0; i < M; i++) {
if(trie[0][i]) {
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
while (!q.empty())
{
int now = q.front();
q.pop();
tag[now] += tag[fail[now]];
for (int i = 0; i < M; i++) {
if(trie[now][i]) {
fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else trie[now][i] = trie[fail[now]][i];
}
}
}
int num[4], bit[4];
int dp[11*11*11*11+1][510];
int main()
{
int n, Case = 1;
mp['A'] = 0, mp['T'] = 1, mp['C'] = 2, mp['G'] = 3;
while (scanf("%d", &n), n) {
init();
char s[50];
for (int i = 1; i <= n; i++)
scanf("%s", s), insertWords(s);
getFail();
scanf("%s", s);
memset(num, 0, sizeof(num));
int len = strlen(s);
for (int i = 0; i < len; i++) num[mp[s[i]]]++;
bit[0] = 1;
bit[1] = bit[0] * (num[0] + 1);
bit[2] = bit[1] * (num[1] + 1);
bit[3] = bit[2] * (num[2] + 1);
int status = num[0] * bit[0] + num[1] * bit[1] + num[2] * bit[2] + num[3] * bit[3];
memset(dp, -1, sizeof(dp));
dp[0][0] = 0;
for (int A = 0; A <= num[0]; A++) {
for (int T = 0; T <= num[1]; T++) {
for (int C = 0; C <= num[2]; C++) {
for (int G = 0; G <= num[3]; G++) {
int State = A * bit[0] + T * bit[1] + C * bit[2] + G * bit[3];
for (int j = 0; j < L; j++) {
if(dp[State][j] < 0) continue;
for (int k = 0; k < 4; k++) {
if(k == 0 && A == num[0]) continue;
if(k == 1 && T == num[1]) continue;
if(k == 2 && C == num[2]) continue;
if(k == 3 && G == num[3]) continue;
int nextnode = trie[j][k];
int nextState = State + bit[k];
dp[nextState][nextnode] = max(dp[nextState][nextnode], dp[State][j] + tag[nextnode]);
}
}
}
}
}
}
int ans = 0;
for (int i = 0; i < L; i++) ans = max(ans, dp[status][i]);
printf("Case %d: %d\n", Case++, ans);
}
}