題意
題解
這裏是一種時空複雜度均爲的暴力做法。感覺這道題順着思路想下去還是挺簡單的。
題目的要求實際上是對於同種顏色的點,都必須在同一條鏈上。
那麼我們自然想到把每一種顏色單獨處理,先找到每種顏色中深度最大的那個點,顯然如果這種顏色能符合一條鏈的條件,那麼這個最深的點一定是鏈的一端,我們設它爲。
接着我們考慮另一端的位置。這裏出現了三種情況(如果不存在這種顏色那麼直接爲0,下面不再討論):
- 只有一個點是這種顏色,即另一端就是。
- 另一端是的祖先。
- 另一端不是的祖先。
其實普遍狀況是第三種(需要優先考慮),前兩種都是符合題目要求的條件的,可以直接計算(具體見代碼,第一種我通過數組預處理了)。
如何判斷是否存在第三種情況呢?其實很簡單,我們在這種顏色中尋找不是的祖先的點中深度最大的點即可。
如果我們發現存在第三種情況,那麼我們就要嘗試驗證它是否符合題目要求的條件。具體而言,設另一端爲,我們就是想要知道x->y這條鏈上是否包含了所有顏色與它們相同的點,也就是要求出一條鏈上有多少點是這種顏色的。
那麼就可以想到使用樹上差分的思想。我們求出每個點到根節點每種顏色分別有幾個點。然後把和的答案相加再減去的答案即可(處特判是否是當前顏色)。
然後問題轉化爲求一個點到根節點的路徑上包含的每種顏色的點的數量。我們可以想到使用主席樹來維護。每個節點在其父親的基礎上進行修改,這樣就可以在複雜度內完成統計。
代碼
#include <bits/stdc++.h>
#define MAX 2000005
#define MAXM 25000005
#define ll long long
#define mid ((l+r)>>1)
using namespace std;
template<typename T>
void read(T &n){
n = 0;
T f = 1;
char c = getchar();
while(!isdigit(c) && c != '-') c = getchar();
if(c == '-') f = -1, c = getchar();
while(isdigit(c)) n = n*10+c-'0', c = getchar();
n *= f;
}
template<typename T>
void write(T n){
if(n < 0) putchar('-'), n = -n;
if(n > 9) write(n/10);
putchar(n%10+'0');
}
int n, cnt, tot;
int head[MAX], vet[MAX*2], Next[MAX*2];
int col[MAX], sz[MAX], dep[MAX], f[MAX][21];
vector<int> v[MAX];
void add(int x, int y){
cnt++;
Next[cnt] = head[x];
head[x] = cnt;
vet[cnt] = y;
}
int s[MAXM], lc[MAXM], rc[MAXM], rt[MAX];
void build(int &p, int l, int r){
if(!p) p = ++tot;
if(l == r) return;
build(lc[p], l, mid);
build(rc[p], mid+1, r);
}
void update(int &p, int l, int r, int x, int last){
p = ++tot;
lc[p] = lc[last], rc[p] = rc[last], s[p] = s[last]+1;
if(l == r) return;
if(mid >= x) update(lc[p], l, mid, x, lc[last]);
else update(rc[p], mid+1, r, x, rc[last]);
}
int query(int l, int r, int p, int x, int y, int z){ //z:lca
if(l == r) return s[x]+s[y]-2*s[z];
if(mid >= p) return query(l, mid, p, lc[x], lc[y], lc[z]);
else return query(mid+1, r, p, rc[x], rc[y], rc[z]);
}
void dfs(int x, int fa){
dep[x] = dep[fa]+1, sz[x] = 1;
f[x][0] = fa;
for(int i = 1; i <= 20; i++) f[x][i] = f[f[x][i-1]][i-1];
update(rt[x], 1, n, col[x], rt[fa]);
for(int i = head[x]; i; i = Next[i]){
int v = vet[i];
if(v == fa) continue;
dfs(v, x);
sz[x] += sz[v];
}
}
int Lca(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--){
if(dep[f[x][i]] >= dep[y]) x = f[x][i];
}
if(x == y) return x;
for(int i = 20; i >= 0; i--){
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
}
return f[x][0];
}
int get(int x, int fa){
for(int i = 20; i >= 0; i--){
if(dep[f[x][i]] > dep[fa]) x = f[x][i];
}
return x;
}
bool cmp(int a, int b){
return dep[a] > dep[b];
}
ll pre[MAX];
int main()
{
cin >> n;
for(int i = 1; i <= n; i++){
read(col[i]);
v[col[i]].push_back(i);
}
int x, y, lca = 0;
for(int i = 1; i < n; i++){
read(x), read(y);
add(x, y), add(y, x);
}
build(rt[0], 1, n);
dfs(1, 0);
for(int i = 1; i <= n; i++){
for(int j = head[i]; j; j = Next[j]){
int v = vet[j];
if(v == f[i][0]) continue;
pre[i] += (ll)sz[v]*(n-sz[v]-1);
}
pre[i] += (ll)(n-sz[i])*(sz[i]-1);
pre[i] /= 2;
pre[i] += n-1;
}
for(int i = 1; i <= n; i++){
if(v[i].empty()){
write((ll)n*(n-1)/2), puts("");
continue;
}
sort(v[i].begin(), v[i].end(), cmp);
x = v[i][0];
if(v[i].size() == 1){
write(pre[x]), puts("");
continue;
}
y = v[i][v[i].size()-1];
for(int j = 1; j < v[i].size(); j++){
lca = Lca(x, v[i][j]);
if(lca != v[i][j]){
y = v[i][j];
break;
}
}
if(y != lca){
int t = query(1, n, i, rt[x], rt[y], rt[lca]);
if(col[lca] == i) t++;
if(t == v[i].size()){
write((ll)sz[x]*sz[y]), puts("");
}
else puts("0");
}
else{
write((ll)sz[x]*(n-sz[get(x, lca)])), puts("");
}
}
return 0;
}