HDU6058 A Kanade's sum

題目鏈接

題意

​ 給定一個長度爲 n 的數組 A,用 A[1...n] 表示,A[1...n]1n 的數一種排列組合。存在一個函數 f(l,r,k) 表示 A[l...r] 中第k大數的值,同時 f(l,r,k)=0rl+1<k 。給定 k 求解nl=1nr=lf(l,r,k)

分析

​ 考慮每個數對答案的貢獻,顯然 A[1...n] 中第 1 到第 k1 大的數對答案的貢獻均爲0。然後思考剩下的數。對於數 ai ,想要使 ai 爲選定 A[l...r] 中第 k 大數,必須要使 A[l...r] 中恰好存在 k1 個大於 ai 的數。如果我們已知大於 ai 的每個數的具體位置,那麼就從 ai 的左邊選取 t 個比它大的數,再從 ai 的右邊選取 kt1 個比它大的數,就能組合成一種可行解。實際選取這樣 k1 個數不一定僅有一種解,對於左邊,第 t 個數到第 t+1 個數之間的所有位置均是合法的選擇,右邊類似。設左邊選 t 個數的合法區間長度爲 lef[t] ,右邊選 k1t 個數的合法區間長度爲 rig[k1t] ,則總共的可行解數爲 lef[t]×rig[k1t] 。最後只要枚舉一下t的值就能快速求解了。

​ 新的問題是如何快速確定大於 ai 的每個數的位置,或者說快速逐個搜索 ai 左側和右側比 ai 大的數的位置。由於不需要利用到比 ai 小的數,不妨從大到小枚舉 1n 中的每一個數,處理完後將對應數的位置放入某個集合中,這樣每次查詢這個集合時,必然都是比 ai 大的數。然而集合中快速搜索最近位置的點的複雜度爲 O(log(n)) ,加上枚舉的複雜度,總複雜度達到了 O(nklog(n)) 。賽時嘗試了一發,不出意料的T掉了。然後考慮到優化搜索過程,不難發現搜索時都是從 ai 的位置向左或者向右逐個查詢,於是想到用鏈表的方式保存左邊和右邊最近點位置。通過鏈表的指針就能在枚舉的過程中快速搜索下一個點的位置了。插入新位置時採用二分的方式找到最近的點的位置,再利用鏈表關係更新即可。最後的總複雜度爲 O(nk+nlog(n))

代碼

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<map>
using namespace std;
#define LL long long
#define MAXN 500500
const int mod=1e9+7;
struct Node{
    int l,r;
}nxt[MAXN];
int bin[MAXN];
int pos[MAXN];
int lef[MAXN];
int rig[MAXN];
int n;
int lowbit(int x){
    return x&-x;
}
void add(int x){
    while(x<=n){
        bin[x]++;
        x+=lowbit(x);
    }
}
int sum(int x){
    int ret=0;
    while(x){
        ret+=bin[x];
        x-=lowbit(x);
    }
    return ret;
}
int query(int l,int r){
    return sum(r)-sum(l-1);
}
void updata(int x){
    int l=1,r=x-1,mid;
    while(l<=r){
        mid=(l+r)>>1;
        if(query(mid,x-1)<1)
            r=mid-1;
        else
            l=mid+1;
    }
    nxt[x].l=r;
    if(r>0)
        nxt[r].r=x;
    l=x+1,r=n;
    while(l<=r){
        mid=(l+r)>>1;
        if(query(x+1,mid)<1)
            l=mid+1;
        else
            r=mid-1;
    }
    nxt[x].r=l;
    if(l<=n)
        nxt[l].l=x;
}
int main(){
    int T,k,a;
    cin>>T;
    while(T--){
        scanf("%d %d",&n,&k);
        memset(bin,0,sizeof(bin));
        nxt[0].l=nxt[n+1].l=0;
        nxt[0].r=nxt[n+1].r=n+1;
        for(int i=1;i<=n;++i){
            scanf("%d",&a);
            pos[a]=i;
            nxt[i].l=0;
            nxt[i].r=n+1;
        }
        for(int i=n;i>n-k+1;i--){
            updata(pos[i]);
            add(pos[i]);
        }
        LL ans=0;
        for(int i=n-k+1;i;i--){
            updata(pos[i]);
            add(pos[i]);
            for(int j=0,curl=pos[i],curr=pos[i];j<=k;++j){
                lef[j]=curl-nxt[curl].l;
                rig[j]=nxt[curr].r-curr;
                curl=nxt[curl].l;
                curr=nxt[curr].r;
            }
            for(int j=0;j<k;++j)
                ans+=i*1ll*lef[j]*rig[k-j-1];
        }
        printf("%I64d\n",ans);
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章