樹上啓發式合併 dsu on tree

dsu on tree 用來解決樹上問題。可以在 O(nlogn)O(n \log n) 中完成對靜態的子樹統計。但是,不支持修改,只能對子樹統計,不能鏈上統計。

我們來看一個問題。有一棵樹,每個點有一個權值。求這棵樹的每一棵子樹的衆數權值之和,如果有多個衆數那麼都要統計。先考慮 O(n2)O(n^2) 的暴力,對於每一棵子樹,遍歷這棵子樹的所有點,用一個桶記錄每一個數出現的次數,統計一下衆數之和,然後清空桶消除影響。

void add(int x, int fa, int val) {
    cnt[col[x]] += val;
    if(cnt[col[x]] > mx) mx = cnt[col[x]], sum = col[x];
    else if(cnt[col[x]] == mx) sum += col[x];
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if (y != fa) add(y, x, val);
    }
}
void dfs(int x, int fa) {
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if(y != fa) dfs(y, x);
    }
    add(x, fa, 1); ans[x] = sum;
    add(x, fa, -1), sum = 0, mx = 0;
}

可以發現,最後一個搜到的兒子是沒有必要消除影響的,因爲消除影響佔用時間最多的兒子是重兒子,所以可以在暴力上加了一個不消除重兒子影響的優化。這是 dsu on tree 的核心思想,dsu on tree 的流程如下,有先後順序。

  1. 遍歷每一個節點
  2. 遞歸解決所有的輕兒子,同時消除遞歸產生的影響
  3. 遞歸重兒子,不消除遞歸的影響
  4. 暴力統計所有輕兒子對答案的影響
  5. 更新該節點的答案
  6. 暴力刪除所有輕兒子對答案的影響

只加了一個不消除重兒子影響的優化。其他都是暴力,好像還是 O(n)O(n) 的,其實不然。因爲一個節點到根的路徑上重鏈和輕鏈個數不會超過 logn\log n 條,只有 dfs 到輕邊時,纔會將輕兒子的子樹中合併到上一級的重鏈,那麼每一個點最多向上合併 logn\log n 次,整體複雜度 O(nlogn)O(n \log n)

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define mp make_pair
#define lson (p << 1)
#define rson (p << 1 | 1)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5, M = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
inline void write(int x){
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
int n, val[N], cnt[N], mx;
ll sum, ans[N];
struct edge{
    int to, nxt;
}e[M];
int head[N], tot;
void addedge(int x, int y){
    e[++tot].to = y, e[tot].nxt = head[x], head[x] = tot;
}
int sz[N], son[N];
void dfs1(int x, int fa){
    sz[x] = 1;
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa){
            dfs1(y, x); sz[x] += sz[y];
            if (sz[y] > sz[son[x]]) son[x] = y;
        }
    }
}
bool vis[N];
void add(int x, int fa, int k){
    cnt[val[x]] += k;
    if (k > 0 && cnt[val[x]] > mx) sum = val[x], mx = cnt[val[x]];
    else if (k > 0 && cnt[val[x]] == mx) sum += val[x];
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && !vis[y]) add(y, x, k);
    }
}
void dfs2(int x, int fa, int flg){
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && y != son[x]) dfs2(y, x, 0);
    }
    if (son[x]) dfs2(son[x], x, 1), vis[son[x]] = 1;
    add(x, fa, 1); ans[x] = sum;
    if (son[x]) vis[son[x]] = 0; 
    if (!flg) add(x, fa, -1), sum = mx = 0;
}
int main() {
    n = read();
    for (int i = 1; i <= n; i++) val[i] = read();
    for (int i = 1; i < n; i++){
        int x = read(), y = read();
        addedge(x, y); addedge(y, x);
    }
    dfs1(1, 0); dfs2(1, 0, 0);
    for (int i = 1; i <= n; i++) printf("%lld ", ans[i]);
    return 0;
}
發佈了40 篇原創文章 · 獲贊 39 · 訪問量 1601
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章