題目鏈接:https://www.luogu.org/problemnew/show/P3384
樹鏈剖分的核心思想就是將一個樹上問題轉化爲鏈上問題,然後就可以用線段樹解決啦。
這裏有一個重鏈和輕鏈的概念。找到每個節點的重兒子作爲他的son存下來。將每個點和他的重兒子們作爲一條鏈搞下去。
#include <bits/stdc++.h>
using namespace std;
const int maxn=400005;
int n,m,root,mod;
int tot,cnt;
int u[maxn],v[maxn],head[maxn],nxt[maxn],dep[maxn],size[maxn];
int sum[maxn],add[maxn],son[maxn],fa[maxn],a[maxn],id[maxn];
int top[maxn],wt[maxn];
void add_edge(int x,int y)
{
tot++;u[tot]=x;v[tot]=y;
nxt[tot]=head[x];head[x]=tot;
}
void build(int root,int l,int r)
{
if(l==r)
{
sum[root]=wt[l];sum[root]%=mod;
return;
}
int mid=(l+r)/2;
build(root*2,l,mid);
build(root*2+1,mid+1,r);
sum[root]=sum[root*2]+sum[root*2+1];
}
void push_down(int root,int l,int r)
{
if(add[root])
{
add[root*2]+=add[root];add[root*2+1]+=add[root];
int mid=(l+r)/2;
sum[root*2]+=1ll*(mid-l+1)*add[root]%mod;
sum[root*2+1]+=1ll*(r-mid)*add[root]%mod;
add[root]=0;
}
}
void update(int root,int l,int r,int L,int R,int k)
{
if(L<=l&&r<=R)
{
add[root]+=k;sum[root]=(sum[root]+1ll*(r-l+1)*k%mod)%mod;
return;
}
push_down(root,l,r);
int mid=(l+r)/2;
if(mid>=L) update(root*2,l,mid,L,R,k);
if(mid+1<=R) update(root*2+1,mid+1,r,L,R,k);
sum[root]=(sum[root*2]+sum[root*2+1])%mod;
}
int query(int root,int l,int r,int L,int R)
{
int ans=0;
if(L<=l&&r<=R)
{
return sum[root]%mod;
}
push_down(root,l,r);
int mid=(l+r)/2;
if(mid>=L) ans+=query(root*2,l,mid,L,R);
ans%=mod;
if(mid+1<=R) ans+=query(root*2+1,mid+1,r,L,R);
return ans%mod;
}
void dfs1(int x,int father,int deep)
{
dep[x]=deep;fa[x]=father;size[x]=1;
int maxson=-1;
for(int i=head[x];i!=-1;i=nxt[i])
{
if(v[i]==father) continue;
dfs1(v[i],x,deep+1);
size[x]+=size[v[i]];
if(size[v[i]]>maxson) {maxson=size[v[i]];son[x]=v[i];}
}
}
void dfs2(int x,int topf)
{
id[x]=++cnt;
top[x]=topf;
wt[cnt]=a[x];
if(!son[x]) return ;
dfs2(son[x],topf);
for(int i=head[x];i!=-1;i=nxt[i])
{
if(v[i]==fa[x]||v[i]==son[x]) continue;
dfs2(v[i],v[i]);
}
}
int qrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=query(1,1,n,id[top[x]],id[x]);
ans%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(1,1,n,id[x],id[y]);
return ans%mod;
}
void uprange(int x,int y,int k)
{
k%=mod;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,id[x],id[y],k);
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d%d%d",&n,&m,&root,&mod);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<n;i++)
{
int x,y;scanf("%d%d",&x,&y);
add_edge(x,y);add_edge(y,x);
}
dfs1(root,0,1);dfs2(root,root);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int k;scanf("%d",&k);
if(k==1)
{
int x,y,z;scanf("%d%d%d",&x,&y,&z);
uprange(x,y,z);
}
if(k==2)
{
int x,y;scanf("%d%d",&x,&y);
printf("%d\n",qrange(x,y));
}
if(k==3)
{
int x,y;scanf("%d%d",&x,&y);
update(1,1,n,id[x],id[x]+size[x]-1,y);
}
if(k==4)
{
int x;scanf("%d",&x);
printf("%d\n",query(1,1,n,id[x],id[x]+size[x]-1));
}
}
return 0;
}