orz zhf
題意
有n個病毒,每天每個病毒的體積會變大ai,每天必須且只能消除一個病毒,代價是病毒的體積,每個病毒的初始體積是bi,天數一共有k天,問最小的代價是多少。
數據範圍
解法
首先有一個比較顯然的dp,設f[i][j]表示前i天,一共清除了j個病毒的最小代價,轉移顯然,然後需要注意的是如果確定了一個選擇病毒的集合,那麼一定是按ai從大到小的順序清除病毒。所以可以事先將病毒按ai從大到小排好序.
代碼:此處f數組優化了一維,path數組表示的是具體方案
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=2e3+5;
inline int read(){
char c=getchar();int t=0,f=1;
while((!isdigit(c))&&(c!=EOF)){if(c=='-')f=-1;c=getchar();}
while((isdigit(c))&&(c!=EOF)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,k;
struct node{
int a,b;
}a[maxn];
bool cmp(node a,node b){
return a.a==b.a?a.b<b.b:a.a>b.a;
}
int f[maxn],path[maxn][maxn];
signed main(){
freopen("12.in","r",stdin);
freopen("12b.out","w",stdout);
n=read(),k=read();
if(k>n)k=n;
for(int i=1;i<=n;i++){
a[i].a=read(),a[i].b=read();
}
sort(a+1,a+1+n,cmp);
memset(f,0x3f,sizeof(f));
f[0]=0;
for(int i=1;i<=n;i++){
for(int j=k;j>=1;j--){
if(f[j-1]+(j-1)*a[i].a+a[i].b<f[j]){
path[i][j]=1;
f[j]=f[j-1]+(j-1)*a[i].a+a[i].b;
}
}
}
/*int x=n,y=k;
while(x>=1){
if(path[x][y]){printf("%lld %lld\n",a[x].a,a[x].b);y--;}
x--;
}*/
printf("%lld\n",f[k]);
return 0;
}
然後觀察path數組的性質,可以發現一個病毒如果被選入了f(n,k)的答案集合,那麼一定也會被選入f(n,k+1)的答案集合,所以根據這個性質我們考對每個病毒二分一個最早被選入答案集合的時刻。這個是一個筆者認爲比較麻煩的問題,具體要解決的問題有:二分的條件,動態維護需要的信息,幸運的是平衡樹可以維護這些東西。
首先我們考慮二分的條件:一個病毒被選入答案集合中,可以認爲是原先有一個答案集合,現在要將一個病毒插進去,考慮一個位置i,現在答案集合中的該位置的數是a,b。那麼這個病毒對答案造成的影響是第一維比a小的都會貢獻一個第一維的值,然後這個a會被算(第一維比a大的次數),注意這裏第一維和a相等的隨便排就可以,可以不用特意處理。然後新加入的病毒產生的貢獻是類似的,但是注意新加入的病毒和原有的答案集合之間的信息並不好維護,所以我們也需要先將病毒按a從大到小排好序,這樣,每次新加入的病毒一定比答案集合中其它的病毒的第一維更小,就會方便很多。
然後有關原來的答案集合,每個病毒維護兩個信息,比它第一維小的病毒數,以及這些病毒的第一維的數值之和,這個可以用平衡樹簡單維護。然後我們就有了代碼:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e6+5;
inline ll read(){
char c=getchar();ll t=0,f=1;
while((!isdigit(c))&&(c!=EOF)){if(c=='-')f=-1;c=getchar();}
while((isdigit(c))&&(c!=EOF)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,k;
struct node{
int a;
ll b;
}a[maxn];
int b[maxn];
bool cmp(node a,node b){
return a.a==b.a?a.b>b.b:a.a>b.a;
}
int ans;
struct tree{
int val,l,r,sz,tag;
ll tot,tag2;
}t[maxn];
int tot,rt,p1,p2;
#define ls t[rt].l
#define rs t[rt].r
inline void pushup(int rt){
t[rt].sz=t[ls].sz+t[rs].sz+1;
}
inline int newnode(int val){tot++;
t[tot].val=val;t[tot].l=t[tot].r=0;t[tot].sz=1;t[tot].tag=t[tot].tot=t[tot].tag2=0;
return tot;
}
inline void pushdown(int rt){
if(t[rt].tag){
t[ls].tag+=t[rt].tag;
t[rs].tag+=t[rt].tag;
t[ls].tot-=1ll*t[rt].tag*a[t[ls].val].a;
t[rs].tot-=1ll*t[rt].tag*a[t[rs].val].a;
t[rt].tag=0;
}
if(t[rt].tag2){
t[ls].tag2+=t[rt].tag2;
t[rs].tag2+=t[rt].tag2;
t[rs].tot+=t[rt].tag2;
t[ls].tot+=t[rt].tag2;
t[rt].tag2=0;
}
}
inline void split(int rt,int &l,int &r,int k){
if(!rt){l=r=0;return ;}
pushdown(rt);
if(t[ls].sz>=k){
r=rt;
split(ls,l,ls,k);
}
else{
l=rt;
split(rs,rs,r,k-t[ls].sz-1);
}
pushup(rt);
}
inline int merge(int x,int y){
if((!x)||(!y))return x^y;
pushdown(x);
pushdown(y);
if(t[x].sz>t[y].sz){
t[x].r=merge(t[x].r,y);
pushup(x);
return x;
}
else{
t[y].l=merge(x,t[y].l);
pushup(y);
return y;
}
}
inline int find(int k){
int u=rt,num=0;
while(1){
if(!u)return num;
pushdown(u);
int va=t[u].val;
ll tmp1=1ll*a[va].a*(num+t[t[u].l].sz)+a[va].b+t[u].tot,tmp2=1ll*a[k].a*(num+t[t[u].l].sz)+a[k].b;
if(tmp1>tmp2){
u=t[u].l;
}
else if(tmp1==tmp2){
return num;
}
else{
num=num+t[t[u].l].sz+1;
u=t[u].r;
}
}
}
inline void modify(int rt,int k,int val){
pushdown(rt);
if(t[ls].sz>=k){t[rs].tag++;t[rs].tot+=val-a[t[rs].val].a;t[rs].tag2+=val;t[rt].tot+=val-a[t[rt].val].a;modify(ls,k,val);}
else if(t[ls].sz+1==k){t[rt].tot+=val-a[t[rt].val].a;t[rs].tag++;t[rs].tag2+=val;t[rs].tot+=val-a[t[rs].val].a;return ;}
else modify(rs,k-t[ls].sz-1,val);
}
void insert(int i){
int sz=find(i);
split(rt,p1,p2,sz);
rt=merge(p1,merge(newnode(i),p2));
if(t[rt].sz>=sz+2)
modify(rt,sz+2,a[i].a);
}
inline int query(int k){
int u=rt;
while(1){
if(t[t[u].l].sz>=k)u=t[u].l;
else if(t[t[u].l].sz+1==k)return t[u].val;
else{k-=t[t[u].l].sz+1;u=t[u].r;}
}
}
signed main(){
//freopen("12.in","r",stdin);
//freopen("12.out","w",stdout);
n=read(),k=read();
if(k>n)k=n;
for(int i=1;i<=n;i++){
a[i].a=read(),a[i].b=read();
}
sort(a+1,a+1+n,cmp);
for(int i=1;i<=n;i++){
insert(i);
}
ll ans=0;
for(int i=1;i<=k;i++){
int tmp=query(i);
b[i]=a[tmp].a;
ans=ans+a[tmp].b;
}
sort(b+1,b+1+k);
for(int i=k;i>=1;i--){
ans=ans+1ll*b[i]*(k-i);
}
printf("%lld\n",ans);
return 0;
}
//經歷了長時間的卡常,終於卡過了。。。
時間複雜度