樹鏈剖分的一些總結(TO DO)

題目鏈接:https://www.luogu.org/problemnew/show/P3384

樹鏈剖分的核心思想就是將一個樹上問題轉化爲鏈上問題,然後就可以用線段樹解決啦。

這裏有一個重鏈和輕鏈的概念。找到每個節點的重兒子作爲他的son存下來。將每個點和他的重兒子們作爲一條鏈搞下去。

#include <bits/stdc++.h>

using namespace std;
const int maxn=400005;
int n,m,root,mod;
int tot,cnt;
int u[maxn],v[maxn],head[maxn],nxt[maxn],dep[maxn],size[maxn];
int sum[maxn],add[maxn],son[maxn],fa[maxn],a[maxn],id[maxn];
int top[maxn],wt[maxn];
void add_edge(int x,int y)
{
	tot++;u[tot]=x;v[tot]=y;
	nxt[tot]=head[x];head[x]=tot;
}
void build(int root,int l,int r)
{
	if(l==r)
	{
		sum[root]=wt[l];sum[root]%=mod;
		return;
	}
	int mid=(l+r)/2;
	build(root*2,l,mid);
	build(root*2+1,mid+1,r);
	sum[root]=sum[root*2]+sum[root*2+1];
}
void push_down(int root,int l,int r)
{
	if(add[root])
	{
		add[root*2]+=add[root];add[root*2+1]+=add[root];
		int mid=(l+r)/2;
		sum[root*2]+=1ll*(mid-l+1)*add[root]%mod;
		sum[root*2+1]+=1ll*(r-mid)*add[root]%mod;
		add[root]=0; 
	}
}
void update(int root,int l,int r,int L,int R,int k)
{
	if(L<=l&&r<=R)
	{
		add[root]+=k;sum[root]=(sum[root]+1ll*(r-l+1)*k%mod)%mod;
		return;
	}
	push_down(root,l,r);
	int mid=(l+r)/2;
	if(mid>=L) update(root*2,l,mid,L,R,k);
	if(mid+1<=R) update(root*2+1,mid+1,r,L,R,k);
	sum[root]=(sum[root*2]+sum[root*2+1])%mod; 
}
int query(int root,int l,int r,int L,int R)
{
	int ans=0;
	if(L<=l&&r<=R)
	{
		return sum[root]%mod;
	}
	push_down(root,l,r);
	int mid=(l+r)/2;
	if(mid>=L) ans+=query(root*2,l,mid,L,R);
	ans%=mod;
	if(mid+1<=R) ans+=query(root*2+1,mid+1,r,L,R);
	return ans%mod;
}
void dfs1(int x,int father,int deep)
{
	dep[x]=deep;fa[x]=father;size[x]=1;
	int maxson=-1;
	for(int i=head[x];i!=-1;i=nxt[i])
	{
		if(v[i]==father) continue;
		dfs1(v[i],x,deep+1);
		size[x]+=size[v[i]];	
		if(size[v[i]]>maxson) {maxson=size[v[i]];son[x]=v[i];}
	}
}
void dfs2(int x,int topf)
{
	id[x]=++cnt;
	top[x]=topf;
	wt[cnt]=a[x];
	if(!son[x]) return ;
	dfs2(son[x],topf);
	for(int i=head[x];i!=-1;i=nxt[i])
	{
		if(v[i]==fa[x]||v[i]==son[x]) continue;
		dfs2(v[i],v[i]);
	}
}
int qrange(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{	
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans+=query(1,1,n,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans+=query(1,1,n,id[x],id[y]);
	return ans%mod;
}
void uprange(int x,int y,int k)
{
	k%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		update(1,1,n,id[top[x]],id[x],k);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	update(1,1,n,id[x],id[y],k);	
} 
int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d%d%d",&n,&m,&root,&mod);
	for(int i=1;i<=n;i++) scanf("%d",&a[i]);
	for(int i=1;i<n;i++)
	{
		int x,y;scanf("%d%d",&x,&y);
		add_edge(x,y);add_edge(y,x);
	}
	dfs1(root,0,1);dfs2(root,root);  
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		int k;scanf("%d",&k);
		if(k==1)
		{
			int x,y,z;scanf("%d%d%d",&x,&y,&z);
			uprange(x,y,z);
		}
		if(k==2)
		{
			int x,y;scanf("%d%d",&x,&y);
			printf("%d\n",qrange(x,y));
		}
		if(k==3)
		{
			int x,y;scanf("%d%d",&x,&y);
			update(1,1,n,id[x],id[x]+size[x]-1,y);
		}
		if(k==4)
		{
			int x;scanf("%d",&x);
			printf("%d\n",query(1,1,n,id[x],id[x]+size[x]-1));
		}
	}
	return 0;	
} 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章