HDU - 5788 Level Up 主席樹+樹狀數組

我們先說明幾個結論:

  1. 一個人的能力變化只會影響自己和上司
  2. 當變化的下屬的能力小於等於上司的中位數 那麼這個中位數會向後移動一位 否則不變 (這個畫圖就能體會到)

所以我們需要用主席樹去查詢 第mid個能力值 和 第mid+1個能力值 用樹狀數組去維護差值 枚舉每一個點 並記錄可以獲得的最大差值 最後答案就是初始的能力值之和+最大差值 具體可以看代碼解釋

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5+10,M = 1e5;
ll c[N],sum[N*30],ans,mx;
int L[N*30],R[N*30],tot,rt[N],a[N],val[N],mid[N];
int h[N],to[N<<1],nex[N<<1],cur;
//a是初始能力值 val[i]是變化i結點後結點i得到的差值 mid是中值  ans初始能力值之和 mx是最大差值 
int siz[N],dfn[N],cnt,n;
void init(){//清空 
	ans=mx=cur=tot=cnt=0;
	memset(h,0,sizeof(h));
}
void add_edge(int x,int y){
	to[++cur]=y;nex[cur]=h[x];h[x]=cur;
}
void add(int x,ll val){//樹狀數組 
	while(x<=M){
		c[x]+=val;
		x+=x&-x;
	}
}
ll query(int x){
	ll ret = 0;
	while(x){
		ret+=c[x];
		x-=x&-x;
	}return ret;
}
void update(int &rt,int lasrt,int l,int r,int pos){//主席樹 
	rt=++tot;sum[rt]=sum[lasrt]+1;
	if(l==r) return;
	L[rt]=L[lasrt],R[rt]=R[lasrt];
	int mid = l+r>>1;
	if(pos<=mid) update(L[rt],L[lasrt],l,mid,pos);
	else {update(R[rt],R[lasrt],mid+1,r,pos);} 
}
int Query(int ql,int qr,int l,int r,int k){
	if(l==r) return l;
	int o = sum[L[qr]]-sum[L[ql]],mid = l+r>>1;
	if(k<=o) return Query(L[ql],L[qr],l,mid,k);
	else {return Query(R[ql],R[qr],mid+1,r,k-o);}
}
void dfs1(int u){
	siz[u]=1,dfn[u]=++cnt;//計算大小和dfs序 
	update(rt[dfn[u]],rt[dfn[u]-1],1,M,a[u]);
	for(int i = h[u]; i; i = nex[i]) dfs1(to[i]),siz[u]+=siz[to[i]];
	if(!h[u]){//如果這是一個葉子結點 
		mid[u]=a[u],val[u]=M-mid[u];//中值就是自己 差值就是1e5-自己 
		ans+=a[u];
	}else{
		int k = siz[u]+1>>1;
		mid[u]=Query(rt[dfn[u]-1],rt[cnt],1,M,k);//找到中值 
		val[u]=Query(rt[dfn[u]-1],rt[cnt],1,M,k+1)-mid[u];//差值就是 第mid+1個位置的值-中值 
		ans+=mid[u];
	}
}
void dfs2(int u){
	add(mid[u],val[u]);//遍歷到一個點時我們把差值放入樹狀數組的mid[u]位置 
	mx=max(mx,query(M)-query(a[u]-1));//如果當前點u是產生變化的點 
	//那麼差值就是他的所有上司裏面mid[]值 大於等於a[u]的人的val[]值(差值)之和 
	for(int i = h[u]; i; i = nex[i]) dfs2(to[i]);
	add(mid[u],-val[u]);//當我搜完所有兒子就可以撤掉我的差值了   
}
int main(){
	while(~scanf("%d",&n)){
		init();
		for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
		for(int i = 2; i <= n; i++){
			int x;
			scanf("%d",&x);
			add_edge(x,i);
		}
		dfs1(1);dfs2(1);
		printf("%lld\n",ans+mx);
	}
	return 0;
}

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章