題目
。
題解
另有的解法。
我寫的是虛樹的做法。對於每種顏色建虛樹,然後在虛樹上求一次子樹size,然後每個虛樹上的點對原樹上對應點做初步貢獻。
最後再一遍求答案。
存答案我用的是結構體,存最大值和編號之和。可以簡單合併。
考慮證明上面的做法對於一種顏色,在初步貢獻和最後的一遍不會重複計算。
因爲虛樹上的點要麼是關鍵點的,要麼就是關鍵點。也就是說虛樹點一定滿足要麼自己爲關鍵點,要麼有多個子樹都有關鍵點。
那麼虛樹點的一定大於每一個子樹的(原樹上的子樹)。(這裏的指的是子樹內當前顏色出現次數 也就是 子樹內關鍵點的數量)
本質上就是虛樹上不會出現只有一個兒子的非關鍵點。
那麼合併一定不會有問題。
時間複雜度
#include <bits/stdc++.h>
using namespace std;
template<class T>inline void read(T &res) {
char ch; while(!isdigit(ch=getchar()));
for(res=ch-'0';isdigit(ch=getchar());res=res*10+ch-'0');
}
#define pb push_back
#define pii pair<int,int>
typedef long long LL;
const int MAXN = 100005;
vector<int>e[MAXN], g[MAXN], vec[MAXN];
int n, col[MAXN], dfn[MAXN], tmr, dep[MAXN], fa[MAXN], sz[MAXN], son[MAXN], top[MAXN];
void dfs1(int u, int ff) {
dep[u] = dep[fa[u] = ff] + (sz[u] = 1);
for(auto v : e[u])
if(v != ff) {
dfs1(v, u); sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
top[u] = tp; dfn[u] = ++tmr;
if(son[u]) dfs2(son[u], tp);
for(auto v : e[u])
if(v != fa[u] && v != son[u]) dfs2(v, v);
}
inline int Lca(int u, int v) {
while(top[u] != top[v]) {
if(dep[top[u]] > dep[top[v]]) u = fa[top[u]];
else v = fa[top[v]];
}
return dep[u] > dep[v] ? v : u;
}
inline bool cmp(int i, int j) { return dfn[i] < dfn[j]; }
bool flg[MAXN];
int siz[MAXN], stk[MAXN], indx;
void ins(int x) {
if(x == stk[indx]) return;
if(!indx) { stk[++indx] = x; return; }
int lca = Lca(x, stk[indx]);
if(lca == stk[indx]) { stk[++indx] = x; return; }
while(indx>1 && dfn[stk[indx-1]] >= dfn[lca]) g[stk[indx-1]].pb(stk[indx]), --indx;
if(lca != stk[indx]) g[lca].pb(stk[indx]), stk[indx] = lca;
stk[++indx] = x;
}
struct node {
int mx; LL sum;
node(int mx=0, LL sum=0):mx(mx), sum(sum){}
inline node operator +(const node &o)const {
return mx > o.mx ? *this : mx < o.mx ? o : node(mx, sum + o.sum);
}
}f[MAXN];
void dfs(int u, int clr) {
siz[u] = flg[u];
for(auto v : g[u]) dfs(v, clr), siz[u] += siz[v];
f[u] = f[u] + node(siz[u], clr); g[u].clear();
}
void getans(int u, int ff) {
for(auto v : e[u]) if(v != ff)
getans(v, u), f[u] = f[u] + f[v];
}
int main () {
read(n);
for(int i = 1; i <= n; ++i) read(col[i]), vec[col[i]].pb(i);
for(int i = 1, x, y; i < n; ++i) read(x), read(y), e[x].pb(y), e[y].pb(x);
dfs1(1, 0), dfs2(1, 1);
for(int i = 1; i <= n; ++i) if(vec[i].size()) {
sort(vec[i].begin(), vec[i].end(), cmp);
indx = 0;
for(auto x : vec[i]) flg[x] = 1, ins(x);
while(indx > 1) g[stk[indx-1]].pb(stk[indx]), --indx;
dfs(stk[1], i);
for(auto x : vec[i]) flg[x] = 0;
}
getans(1, 0);
for(int i = 1; i <= n; ++i) printf("%lld ", f[i].sum);
}