問題描述
Say you have an array for which the ith element is the price of a given stock on day i.
Design an algorithm to find the maximum profit. You may complete at most k transactions.
Example
Given prices = [4,4,6,1,1,4,2,5], and k = 2, return 6.
簡而言之,就是從prices數組中,取出最多2*k個數,相鄰的每對(比如第1個數和第2個數,第3個數和第4個數),其總和要最大。
Note
You may not engage in multiple transactions at the same time (i.e., you must sell the stock before you buy again).
題意分析
這是一個DP (Dynamic Programming)問題。
解法1
t(n, k):表示前n天,進行最多k次交易,所能獲得的最大收益。並且最後一次交易發生在第n天
那麼有:
t(n, k) = max { t(m, k-1) + prices[n] - prices_min(m+1, n) }, 0 <= m <= n-1
t(n, k) = max { t(n, k), t(n, k-1) } // 只最多進行 k-1 次交易
其中 prices_min(m+1, n) 表示第m+1天到第n天之間的最低股票價格
這樣的話,可以對m從n-1變換到0,進行掃描一次,並且更新prices_min,這樣的話,要求得到t(n, k) 的複雜度爲 o(n),因此總的時間複雜度爲 O(k * n^2)
解法2
解法一的複雜度還是有點大,當k趨近於n的時候,複雜度可以達到 O(n^3),不太可以接受。因此我們需要一個更好的dp方法。
g(n, k)表示前n天,進行k次交易,所獲得最大收益
l(n, k) 表示前n天,進行k次交易,且最後一天進行了交易(賣出股票),所獲得最大收益
g(n, k) = max{ g(n-1, k), l(n, k) } // 前n天最大收益,要不最後一次交易在前n-1天內,要不發生在第n天
l(n, k) = max{ g(m, k-1) + prices[n] - prices[m+1] }, 0 <= m <= n-1
l(n, k) = max{ l(n, k), l(n, k-1) }
在計算 l(n, k) 的時候,需要遍歷 0 ~ n-1,如果暴力的話,複雜度還是 O(n),最後的複雜度就是O(k * n^2)
那麼這裏有個優化算法:
設max_diff 爲 max { g(m, k-1) - prices[m+1] }, 0 <= m <= n-1。在計算 l(0, k) ~ l(n, k) 的時候,不斷更新max_diff,那麼這樣的話,其實就把計算 l(n,k) 的平攤複雜度降了下來,複雜度爲 O(1)。那麼總的時間複雜度爲 O(kn)。
具體代碼:
int maxProfit(int k, vector<int> &prices) {
// write your code here
int n = prices.size();
if (k > n/2) k = n/2;
if (n <= 1) return 0;
int *l = new int[n+1];
int *g = new int[n+1];
for (int i = 1; i <= k; i++) {
if (i == 1) {
int current_min = prices[0], max_profit = 0;
for (int j = 0; j <= n; j++) {
if (j == 0) { l[j] = g[j] = 0; continue; }
current_min = min(current_min, prices[j-1]);
max_profit = max(max_profit, prices[j-1]- current_min);
l[j] = prices[j-1] - current_min;
g[j] = max_profit;
}
// print(l, n+1); print(g, n+1); printf("---\n");
} else {
int max_diff = g[0] - prices[0];
l[0] = g[0] = 0;
for (int j = 1; j <= n; j++) {
max_diff = max(max_diff, g[j-1] - prices[j-1]);
// printf("j:%d, max_diff:%d\n ", j, max_diff);
l[j] = max(l[j], max_diff + prices[j-1]);
}
// cout << endl;
// update g
for (int j = 1; j <= n; j++) {
g[j] = max(g[j-1], l[j]);
}
// print(l, n+1); print(g, n+1); printf("---\n");
}
}
return g[n];
}