P4072 [SDOI2016]征途(斜率優化 dp + 細節)

在這裏插入圖片描述


設序列 a1,a2,a3,a4,...,ana_1,a_2,a_3,a_4,...,a_n 的平均值爲 xx,權值和爲 tottot,方差 S=i=1n(aix)2nS = \displaystyle\frac{\displaystyle\sum_{i = 1}^n(a_i -x)^2}{n}
乘上 n2n^2 得到:S=i=1n(maitot)2nS = \displaystyle\frac{\displaystyle\sum_{i = 1}^n(m*a_i -tot)^2}{n}

統計分子部分,最後再除以 n。

題目轉化過來就是將 n 個數分成 m 段,使得最後 SS 儘可能小。
每一段對 SS 的分子部分的貢獻不難計算,考慮 dp,dp[i][j] 表示前 i 個數分成 j 段對 S 的最小貢獻,答案顯然爲 dp[n][m]

轉移方程:dp[i][t]=minj=1i(dp[j][t1]+(m(sum[i]sum[j])tot)2)dp[i][t] = \displaystyle\min_{j = 1}^i(dp[j][t - 1] + (m * (sum[i] - sum[j]) - tot)^2)

注意初值,對所有的 i<j,dp[i][j]=infi <j,dp[i][j] = inf

直接做的複雜度是 n2mn^2m,考慮優化,將式子展開可以得到:

dp[i][t]=minj=1i(dp[j][t1]+m2sum[i]2+m2sum[j]22m2sum[i]sum[j]+tot22mtotsum[i]+2mtotsum[j])dp[i][t] = \displaystyle\min_{j = 1}^i(dp[j][t - 1] + m^2*sum[i]^2+m^2*sum[j]^2-2m^2*sum[i]*sum[j]+tot^2-2m*tot*sum[i]+2*m*tot*sum[j])

出現了 2m2sum[i]sum[j]2m^2*sum[i]*sum[j],考慮斜率優化,斜率爲 2msum[i]2m*sum[i] (移項後的斜率),斜率單調遞增,要求 dp[i][t]dp[i][t] 最小值,維護下凸包。

j>kj > k,在 jj 點比在 kk 點轉移更優,則滿足 :
dp[j][t1]+m2sum[i]2+m2sum[j]22m2sum[i]sum[j]+tot22mtotsum[i]+2mtotsum[j]dp[j][t - 1] + m^2*sum[i]^2+m^2*sum[j]^2-2m^2*sum[i]*sum[j]+tot^2-2m*tot*sum[i]+2*m*tot*sum[j] <<

dp[k][t1]+m2sum[i]2+m2sum[k]22m2sum[i]sum[k]+tot22mtotsum[i]+2mtotsum[k]dp[k][t - 1] + m^2*sum[i]^2+m^2*sum[k]^2-2m^2*sum[i]*sum[k]+tot^2 -2m*tot*sum[i]+2*m*tot*sum[k]

式子非常長,後面一段就是把 j 換成 k,消掉相同的並進行移項化簡,最後可以得到:

c[i][t]=dp[i][t]+m2sum[i]2+2mtotsum[i]c[i][t] = dp[i][t] + m^2*sum[i]^2 + 2m*tot*sum[i]

代入得:c[j][t1]c[k][t1]sum[j]sum[t]<2msum[i]\displaystyle\frac{c[j][t - 1] - c[k][t - 1]}{sum[j] - sum[t]} < 2m*sum[i]

用單調隊列維護 c[j][t1]c[k][t1]sum[j]sum[t]\displaystyle\frac{c[j][t - 1] - c[k][t - 1]}{sum[j] - sum[t]} 單增,根據決策單調性尋找最優轉移點,複雜度降爲 O(nm)O(n*m)

對第二維可以進行滾動,最後答案要記得除以 m。


代碼:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 3e3 + 10;
typedef long long ll;
const ll inf = 1e17;
int n,N;
ll sum[maxn],dp[maxn],tp[maxn],c[maxn],m[5],tot;
int q[maxn],front,rear;
ll calc(int x,int y) {
	ll tmp = m[1] * (sum[x] - sum[y]) - tot;
	return tp[y] + tmp * tmp;
}
int main() {
	scanf("%d%d",&n,&m[1]);
	for (int i = 2; i <= 4; i++)
		m[i] = 1ll * m[i - 1] * m[1];
	for (int i = 1,x; i <= n; i++)
		scanf("%d",&sum[i]);
	for (int i = 1; i <= n; i++) 
		sum[i] += sum[i - 1];
	tot = sum[n];
	for (int j = 1; j <= n; j++)
		tp[j] = dp[j] = inf;
	for (int t = 1; t <= m[1]; t++) {
		front = rear = 0;
		q[++rear] = t - 1;
		for (int i = t; i <= n; i++) {
			while (front + 1 < rear && calc(i,q[front + 1]) >= calc(i,q[front + 2]))
				front++;
			dp[i] = calc(i,q[front + 1]);
			while (front + 1 < rear && (c[i] - c[q[rear]]) * (sum[q[rear]] - sum[q[rear - 1]]) 
				<= (c[q[rear]] - c[q[rear - 1]]) * (sum[i] - sum[q[rear]]))
				rear--;
			q[++rear] = i;
		}
		for (int i = 0; i <= n; i++) {
			tp[i] = dp[i];
			c[i] = tp[i] + m[2] * sum[i] * sum[i] + 2 * m[1] * tot * sum[i];
			dp[i] = inf;
		}
	}	
	printf("%lld\n",tp[n] / m[1]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章