知識儲備:線段樹、圖的存儲與遍歷、狀態壓縮
題意翻譯
你有一棵以1爲根的有根樹,有n個點,每個節點初始有一個顏色c[i]。
有兩種操作:
1 v c 將以v爲根的子樹中所有點顏色更改爲c
2 v 查詢以v爲根的子樹中的節點有多少種不同的顏色
題目鏈接:https://www.luogu.org/problem/CF620E
首先我們講一下dfs序在這道題裏面的應用
定義:一棵樹從根節點開始進行深度優先搜索,用一個時間戳記錄下來每一個點被訪問的時間,得到的序列就叫dfs序。
注意:以不同方式存樹會有不同的dfs序,但是對於同一個鄰接表搜出來的dfs序永遠一樣
這棵樹的dfs序(不唯一)就是:1 4 6 3 7 10 5 8 2 9
這樣的話我們就可以把一個點的dfs序代表這個點,這樣不論樹的形狀是怎樣的,dfs序都可以把它轉化成線性結構
我們通過dfs把這顆樹的dfs序存儲在pos數組中
同時,我們還要記錄一個點的入點時間戳與出點時間戳。因爲題目要求我們支持更改一整個子樹,如果我們把它轉化成線性結構就可以用線段樹維護它而不用dfs了,但是線段樹要有左右端點,而這個左右端點就是用in與out數組實現。
dfs代碼:
我是用鄰接表存圖,tot爲時間,pos記錄當前時間訪問到的結點。
void dfs(int x,int fa){//用dfs序將樹形結構轉爲線性結構
tot++;
in[x]=tot;
pos[tot]=x;
for(int i=head[x];i+1;i=e[i].next){
int k=e[i].to;
if(k==fa)continue;
dfs(k,x);
}
out[x]=tot;//其實我們out記錄的是該子樹中dfs序最大結點的dfs序,所以tot不加一
return;
}
這樣的話,我們就可以用pos數組建樹了!
struct tree{
long long sum;
long long tag;//記錄要修改爲哪個顏色
int l,r;
}t[1600040];
我們把顏色狀態壓縮到一個long long中,然後要建立一顆與運算的線段樹:
首先是上傳、下傳以及建樹:
void pushup(int rt){
t[rt].sum=t[rt<<1].sum|t[rt<<1|1].sum;
return;
}
void pushdown(int rt){
if(t[rt].tag!=0){//或運算不像區間和,這裏0是影響答案的
t[rt<<1].sum=t[rt].tag;
t[rt<<1].tag=t[rt].tag;
t[rt<<1|1].sum=t[rt].tag;
t[rt<<1|1].tag=t[rt].tag;
t[rt].tag=0;
}
return;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
if(l==r){
t[rt].sum=(long long)1<<(c[pos[l]]);//
t[rt].tag=0;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
然後,我們分別來處理兩種操作
首先使修改操作:
將以x結點爲根的子樹的顏色修改爲p:
void modify(int rt,int l,int r,int p){//將dfs序從l到r這一區間的顏色改爲p 注意狀壓
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].sum=(long long)1<<p;//
t[rt].tag=(long long)1<<p;//
return ;
}
int mid=(t[rt].l+t[rt].r)>>1;
pushdown(rt);
if(l<=mid)modify(rt<<1,l,r,p);
if(r>mid)modify(rt<<1|1,l,r,p);
pushup(rt);
return;
}
還記得我一開始的dfs過程嗎?我們每訪問一個點就記錄它的入點時間戳,然後遍歷它每一個兒子,然後記錄出點時間戳,這樣的話一個點的入、出點時間戳之間就包含了它的所有孩子。(這有點像笛卡爾樹中序遍歷就是一段區間)。
所以l和r就分別是in[x]和out[x]
然後是查詢操作:
long long query(int rt,int l,int r){//查詢dfs序從l到r的顏色個數
if(t[rt].l>=l&&t[rt].r<=r){
return t[rt].sum;
}
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
long long rec=0;
if(l<=mid)rec|=query(rt<<1,l,r);
if(r>mid)rec|=query(rt<<1|1,l,r);
pushup(rt);
return rec;
}
但要注意的是,我們返回的是一個狀壓之後的數字,所以我們還要分解它,這裏有兩種寫法:
第一種比較直觀,但有一些慢
while(res) {
s += res&1;
res >>= 1;
}
實際思想就是按位與記錄答案
還有一種是用lowbit,沒學過樹狀數組的話可以先去看一下樹狀數組。
返回某一個數二進制下第一個1所代表的值,注意:是值,不是位置
long long lowbit(long long x){
return x&(-x);
}
for(long long j=num;j>0;j-=lowbit(j))ans++;
然後接結束了!
下附AC代碼:
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
int read(){
char s;
int x=0,f=1;
s=getchar();
while(s<'0'||s>'9'){
if(s=='-')f=-1;
s=getchar();
}
while(s>='0'&&s<='9'){
x*=10;
x+=s-'0';
s=getchar();
}
return x*f;
}
struct tree{
long long sum;
long long tag;//記錄要修改爲哪個顏色
int l,r;
}t[1600040];
int c[400040];
int tot=0;
int pos[400040];//時間戳所對應的點
int in[400040];//入點時間戳
int out[400040];//出點時間戳
struct edge{
int to,next;
}e[800080];
int eid=0;
int head[400040];
void insert(int u,int v){
eid++;
e[eid].to=v;
e[eid].next=head[u];
head[u]=eid;
}
void dfs(int x,int fa){//用dfs序將樹形結構轉爲線性結構
tot++;
in[x]=tot;
pos[tot]=x;
for(int i=head[x];i+1;i=e[i].next){
int k=e[i].to;
if(k==fa)continue;
dfs(k,x);
}
out[x]=tot;//其實我們out記錄的是該子樹中dfs序最大結點的dfs序,所以tot不加一
return;
}
void pushup(int rt){
t[rt].sum=t[rt<<1].sum|t[rt<<1|1].sum;
return;
}
void pushdown(int rt){
if(t[rt].tag!=0){//或運算不像區間和,這裏0是影響答案的
t[rt<<1].sum=t[rt].tag;
t[rt<<1].tag=t[rt].tag;
t[rt<<1|1].sum=t[rt].tag;
t[rt<<1|1].tag=t[rt].tag;
t[rt].tag=0;
}
return;
}
void build(int rt,int l,int r){
t[rt].l=l;
t[rt].r=r;
if(l==r){
t[rt].sum=(long long)1<<(c[pos[l]]);//
t[rt].tag=0;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
void modify(int rt,int l,int r,int p){//將dfs序從l到r這一區間的顏色改爲p 注意狀壓
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].sum=(long long)1<<p;//
t[rt].tag=(long long)1<<p;//
return ;
}
int mid=(t[rt].l+t[rt].r)>>1;
pushdown(rt);
if(l<=mid)modify(rt<<1,l,r,p);
if(r>mid)modify(rt<<1|1,l,r,p);
pushup(rt);
return;
}
long long query(int rt,int l,int r){//查詢dfs序從l到r的顏色個數
if(t[rt].l>=l&&t[rt].r<=r){
return t[rt].sum;
}
pushdown(rt);
int mid=(t[rt].l+t[rt].r)>>1;
long long rec=0;
if(l<=mid)rec|=query(rt<<1,l,r);
if(r>mid)rec|=query(rt<<1|1,l,r);
pushup(rt);
return rec;
}
int n,m;
long long lowbit(long long x){
return x&(-x);
}
int main(){
memset(head,-1,sizeof(head));
n=read();
m=read();
for(int i=1;i<=n;i++){
c[i]=read();
}
for(int i=1;i<n;i++){
int x,y;
x=read();
y=read();
insert(x,y);
insert(y,x);
}
dfs(1,0);
build(1,1,n);
for(int i=1;i<=m;i++){
int a;
a=read();
if(a==1){
int x,y;
x=read();
y=read();
modify(1,in[x],out[x],y);
//cout<<query(1,in[x],in[x])<<endl;
}
else{
int x;
x=read();
long long num=query(1,in[x],out[x]);
int ans=0;
//cout<<num<<endl;
for(long long j=num;j>0;j-=lowbit(j))ans++;
cout<<ans<<endl;
}
}
return 0;
}