題目
給出N個數,在裏面選出不超過K段連續的子序列,使其兩兩不相交,求總和的最大值(可以一段都不選)
數據範圍
N,K<= 100000
對於一個數a滿足 -1000000000 <= a <= 100000000
題解
首先看到這道題很容易想到是dp,然後再加上一個優化
可是這裏的N,K太大O(NK)是會超時的
所以換方法,然後用了一種不知道爲什麼的算法:
用線段樹維護,然後每次選出最大的一段子序列,加入答案中,再將這段區間選的數都乘上-1,一直重複K次或者是選到了最大的子序列是負數時,就退出
好像是用到了費用流的思想,然而我並想不出來
所以只會簡單的證明一下
證明:每次取得的區間只有兩種可能:
1.被之前選的區間之內
2.選擇一個還未選擇過的區間(這裏指的是從左端點到右端點一直沒有被選過)
證明
首先一段區間兩端一定都是正數(對於當前來說)
對於當前情況
如果有一個區間i與之前選過的區間j相交,那麼有一個端點是在之前選過的區間之內,且這個端點的值一定是正數
也不難發現,與之前選過的區間相交的這一段區間的和一定是大於0的,那就是說之前這個端點的值是負數,且相交區間和的應該是負數,那麼在第一次選這個j的時候時就可以選擇不要相交的區間,會使和最大,反證法可得命題不成立所以不會相交
同時i包含j也可以用反證法證明
代碼
#include <iostream>//對於這道題,首先是要維護區間最大值最小值,同時也要維護從最左端開始的最大
#include <cstdio>//與最小區間,右端點也要這樣,纔可以進行轉移
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;
#define ll long long
const int MAXN = 1e5 + 3;
struct edge{
int l ,r;
ll sum;
edge(){}
edge( int L , int R , ll S ){
l = L;r = R;sum = S;
}
};
struct node{
int l , r;
edge zsum , lsum ,rsum;
edge zmin , lmin , rmin;
int lazy;
}tre[MAXN*4];
int n , K;
ll a[MAXN];
edge min_( edge a , edge b ){
edge ans;
if( a.sum >= b.sum )
ans = a;
else
ans = b;
return ans;
}
void push_( int i ){
tre[i].zsum = min_( tre[i*2].zsum , tre[i*2+1].zsum );
if( tre[i*2].rsum.sum + tre[i*2+1].lsum.sum > tre[i].zsum.sum ){
tre[i].zsum.sum= tre[i*2].rsum.sum + tre[i*2+1].lsum.sum;
tre[i].zsum.l = tre[i*2].rsum.l , tre[i].zsum.r = tre[i*2+1].lsum.r;
}
tre[i].lsum = tre[i*2].lsum;
if( tre[i*2].lsum.r - tre[i*2].lsum.l + 1 == tre[i*2].r - tre[i*2].l + 1 ){
if( tre[i].lsum.sum < tre[i*2].lsum.sum + tre[i*2+1].lsum.sum )
tre[i].lsum.sum = tre[i*2].lsum.sum + tre[i*2+1].lsum.sum , tre[i].lsum.l = tre[i].l , tre[i].lsum.r = tre[i*2+1].lsum.r;
}
tre[i].rsum = tre[i*2+1].rsum;
if( tre[i*2+1].rsum.r - tre[i*2+1].rsum.l + 1 == tre[i*2+1].r - tre[i*2+1].l + 1 ){
if( tre[i].rsum.sum < tre[i*2].rsum.sum + tre[i*2+1].rsum.sum )
tre[i].rsum.sum = tre[i*2].rsum.sum + tre[i*2+1].rsum.sum , tre[i].rsum.l = tre[i*2].rsum.l , tre[i].rsum.r =tre[i].r ;
}//
if( tre[i*2].zmin.sum <= tre[i*2+1].zmin.sum )
tre[i].zmin = tre[i*2].zmin;
else
tre[i].zmin = tre[i*2+1].zmin;
if( tre[i*2].rmin.sum + tre[i*2+1].lmin.sum < tre[i].zmin.sum ){
tre[i].zmin.sum = tre[i*2].rmin.sum + tre[i*2+1].lmin.sum;
tre[i].zmin.l = tre[i*2].rmin.l , tre[i].zmin.r = tre[i*2+1].lmin.r;
}
tre[i].lmin = tre[i*2].lmin;
if( tre[i*2].lmin.r - tre[i*2].lmin.l + 1 == tre[i*2].r - tre[i*2].l + 1 ){
if( tre[i].lmin.sum > tre[i*2].lmin.sum + tre[i*2+1].lmin.sum )
tre[i].lmin.sum = tre[i*2].lmin.sum + tre[i*2+1].lmin.sum , tre[i].lmin.l = tre[i].l , tre[i].lmin.r = tre[i*2+1].lmin.r;
}
tre[i].rmin = tre[i*2+1].rmin;
if( tre[i*2+1].rmin.r - tre[i*2+1].rmin.l + 1 == tre[i*2+1].r - tre[i*2+1].l + 1 ){
if( tre[i].rmin.sum > tre[i*2].rmin.sum + tre[i*2+1].rmin.sum )
tre[i].rmin.sum = tre[i*2].rmin.sum + tre[i*2+1].rmin.sum , tre[i].rmin.l = tre[i*2].rmin.l , tre[i].rmin.r =tre[i].r ;
}
}
void build( int i , int l , int r ){
tre[i].l = l , tre[i].r = r;tre[i].lazy = 1;
if( l == r ){
tre[i].zsum =edge( l , r , a[l] );tre[i].lsum =edge( l , r , a[l] );tre[i].rsum =edge( l , r , a[l] );
tre[i].zmin =edge( l , r , a[l] );tre[i].lmin =edge( l , r , a[l] );tre[i].rmin =edge( l , r , a[l] );
return ;
}
int mid = ( l + r ) / 2;
build( i * 2 , l , mid );
build( i * 2 + 1 , mid + 1, r );
push_( i );
}
void pushdown( int i ){
if( tre[i].lazy == -1 ){
tre[i*2].zsum.sum *= -1;tre[i*2].zmin.sum *= -1;
tre[i*2].lsum.sum *= -1;tre[i*2].lmin.sum *= -1;
tre[i*2].rsum.sum *= -1;tre[i*2].rmin.sum *= -1;
swap( tre[i*2].zsum , tre[i*2].zmin );
swap( tre[i*2].lsum , tre[i*2].lmin );
swap( tre[i*2].rsum , tre[i*2].rmin );
tre[i*2+1].zsum.sum *= -1;tre[i*2+1].zmin.sum *= -1;
tre[i*2+1].lsum.sum *= -1;tre[i*2+1].lmin.sum *= -1;
tre[i*2+1].rsum.sum *= -1;tre[i*2+1].rmin.sum *= -1;
swap( tre[i*2+1].zsum , tre[i*2+1].zmin );
swap( tre[i*2+1].lsum , tre[i*2+1].lmin );
swap( tre[i*2+1].rsum , tre[i*2+1].rmin );
tre[i].lazy = 1;
tre[i*2].lazy *= -1;tre[i*2+1].lazy *= -1;
}
}
void change( int i , int l , int r ){
if( tre[i].l > r || tre[i].r < l )return ;
if( l <= tre[i].l && r >= tre[i].r ){
tre[i].lazy *= -1;
tre[i].zsum.sum *= -1;tre[i].zmin.sum *= -1;
tre[i].lsum.sum *= -1;tre[i].lmin.sum *= -1;
tre[i].rsum.sum *= -1;tre[i].rmin.sum *= -1;
swap( tre[i].zsum , tre[i].zmin );
swap( tre[i].lsum , tre[i].lmin );
swap( tre[i].rsum , tre[i].rmin );
return ;
}
pushdown( i );
change( i * 2 , l , r );
change( i * 2 + 1 , l , r );
push_( i );
}
void read( ll &x ){
char s = getchar();int f=1;x= 0 ;
while( s < '0' || s > '9' ){
if( s == '-' )
f = -1;
s = getchar();
}
while( s >= '0' && s <= '9' ){
x = x * 10 + s - '0';
s = getchar();
}
x *= f;
}
int main()
{
freopen( "maxksum.in" , "r" , stdin );
freopen( "maxksum.out" , "w" , stdout );
scanf( "%d%d" , &n , &K );
for( int i = 1 ; i <= n ; i ++ ){
read( a[i] );
}
build( 1 , 1 , n );
ll ans = 0;
while( K -- ){
ll tot = tre[1].zsum.sum;
if( tot <= 0 ) break;
ans += tot;
change( 1 , tre[1].zsum.l , tre[1].zsum.r );
}
printf( "%lld" , ans );
return 0;
}