hdu3341AC自動機+變進制狀壓dp

hdu3341
題意
給m個模式串和一個母串(字符串都是由ATCG組成),求這個母串重組後最多包含多少個模式串,可以重疊。
思路
看題目數據範圍很小,首先暴力的思想,把母串的所有可能的排列方式都求一遍,取最大值。建立模式串的AC自動機,dp[i][x1][x2][x3][x4]表示i狀態下有x1個A,x2個T,x3個C,x4個G。狀態轉移:dp[j][相應字母加一]=max(dp[j][相應字母加一],dp[i][x1][x2][x3][x4]+tag[trie[i][j]]),然後對於每個狀態取字符個數和爲n的最大值就可以了;但是!40404040500超內存,所以不行,考慮狀態壓縮:字符串總長40,當每種字符個數相等時,11 * 11 * 11 * 11表示的狀態是最大的,也就是長度40的字符串表示的狀態最多11^4,所以可以令MaxA、MaxC、MaxG、MaxT分別表示四種字符出現的個數,那麼T字符的權值爲1,G字符的權值爲(MaxT + 1),C字符的權值爲(MaxG + 1) * (MaxT + 1),A字符的權值爲(MaxC + 1) * (MaxG + 1) * (MaxT + 1),進行進制壓縮之後總的狀態數不會超過11^4,可以用DP[i][j]表示在trie的i號結點時ACGT四個字符個數的壓縮狀態爲j時的字符串包含模式串的最多數目,然後就是進行O(4500114)的狀態轉移了。
參考博客

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 500 + 10, M = 4;
int trie[MAXN][M];
int tag[MAXN];
int fail[MAXN];
int L = 0, root;
map<char, int> mp;
int newnode() {
    for (int i = 0; i < M; i++) trie[L][i] = 0;
    tag[L++] = 0;
    return L - 1;
}
void init() {
    L = 0;
    root = newnode();
}
void insertWords(char *s)
{
    int now = root, SIZE = strlen(s);
    for (int i = 0; i < SIZE; i++) {
        int next = mp[s[i]];
        if(!trie[now][next])
            trie[now][next] = newnode();
        now = trie[now][next];
    }
    tag[now]++;
}
void getFail()//一個節點的fail指針是指向 這個節點表示的字符串的最長後綴串的最後一個節點
{
    queue<int> q;
    for(int i = 0; i < M; i++) {
        if(trie[0][i]) {
            fail[trie[0][i]] = 0;
            q.push(trie[0][i]);
        }
    }
    while (!q.empty())
    {
        int now = q.front();
        q.pop();
        tag[now] += tag[fail[now]];
        for (int i = 0; i < M; i++) {
            if(trie[now][i]) {
                fail[trie[now][i]] = trie[fail[now]][i];
                q.push(trie[now][i]);
            }
            else trie[now][i] = trie[fail[now]][i];
        }
    }
}
int num[4], bit[4];
int dp[11*11*11*11+1][510];
int main()
{
    int n, Case = 1;
    mp['A'] = 0, mp['T'] = 1, mp['C'] = 2, mp['G'] = 3;
    while (scanf("%d", &n), n) {
        init();
        char s[50];
        for (int i = 1; i <= n; i++)
            scanf("%s", s), insertWords(s);
        getFail();
        scanf("%s", s);
        memset(num, 0, sizeof(num));
        int len = strlen(s);
        for (int i = 0; i < len; i++) num[mp[s[i]]]++;
        bit[0] = 1;
        bit[1] = bit[0] * (num[0] + 1);
        bit[2] = bit[1] * (num[1] + 1);
        bit[3] = bit[2] * (num[2] + 1);
        int status = num[0] * bit[0] + num[1] * bit[1] + num[2] * bit[2] + num[3] * bit[3];
        memset(dp, -1, sizeof(dp));
        dp[0][0] = 0;
        for (int A = 0; A <= num[0]; A++) {
            for (int T = 0; T <= num[1]; T++) {
                for (int C = 0; C <= num[2]; C++) {
                    for (int G = 0; G <= num[3]; G++) {
                        int State = A * bit[0] + T * bit[1] + C * bit[2] + G * bit[3];
                        for (int j = 0; j < L; j++) {
                            if(dp[State][j] < 0) continue;
                            for (int k = 0; k < 4; k++) {
                                if(k == 0 && A == num[0]) continue;
                                if(k == 1 && T == num[1]) continue;
                                if(k == 2 && C == num[2]) continue;
                                if(k == 3 && G == num[3]) continue;
                                int nextnode = trie[j][k];
                                int nextState = State + bit[k];
                                dp[nextState][nextnode] = max(dp[nextState][nextnode], dp[State][j] + tag[nextnode]);
                            }
                        }
                    }
                }
            }
        }
        int ans = 0;
        for (int i = 0; i < L; i++) ans = max(ans, dp[status][i]);
        printf("Case %d: %d\n", Case++, ans);
    }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章