重心分解(Centroid Decomposition)
首先要找到樹的中心,樹的重心的定義是:刪除該中心結點得到的最大子樹的頂點數最少的頂點就是樹的重心。運用dfs的方法很容易實現,代碼爲:
void get_hvy(int u, int fa){
siz[u] = 1, maxx[u] = 0;
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(v != fa && !vis[v]){
get_hvy(v, u);
siz[u] += siz[v];
maxx[u] = max(maxx[u], siz[v]);
}
}
maxx[u] = max(maxx[u], S-siz[u]);
if(!hvy || maxx[hvy] > maxx[u]) hvy = u;
}
對於題
根據重心把樹分解爲若干子樹,那麼所要求的頂點對(計算樹上頂點對的距離)則分爲以下三類:
- 頂點屬於同一子樹的頂點對
- 頂點不屬於同一子樹的頂點對
- 頂點和其他頂點組成的頂點對
那麼通過當前重心求的到其他頂點的距離中包括第一種情況,但是這樣會造成重複統計,因此我們需要在對應子樹下面重新計算一邊並進行去重操作。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e4+100;
struct Edge{
int to, val, next;
}e[maxn<<2];
int head[maxn], tot;
int siz[maxn], maxx[maxn], vis[maxn], hvy, S;
int dep[maxn], now[maxn], cnt, ans;
int n, k;
void addedge(int from, int to, int val){
e[tot].to = to;
e[tot].val = val;
e[tot].next = head[from];
head[from] = tot++;
}
void get_hvy(int u, int fa){
siz[u] = 1;
maxx[u] = 0;
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(v != fa && !vis[v]){
get_hvy(v, u);
siz[u] += siz[v];
maxx[u] = max(maxx[u], siz[v]);
}
}
maxx[u] = max(maxx[u], S-siz[u]);
if(!hvy || maxx[u] < maxx[hvy]) hvy = u;
}
void get_dep(int u, int fa){
now[cnt++] = dep[u];
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(v != fa && !vis[v]){
dep[v] = dep[u] + e[i].val;
get_dep(v, u);
}
}
}
int get_sum(int u, int dst){
dep[u] = dst; cnt = 0;
get_dep(u, -1);
sort(now, now+cnt);
int res = 0;
for(int l = 0, r = cnt-1; l < r; l++){
while(l < r && now[l] + now[r] > k) r--;
res += max(0, r-l);
}
return res;
}
void get_ans(int u, int fa){
vis[u] = 1;
ans += get_sum(u, 0);
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(!vis[v] && v != fa){
ans -= get_sum(v, e[i].val);
hvy = 0, S = siz[v];
get_hvy(v, u);
get_ans(hvy, -1);
}
}
}
int main(){
while(scanf("%d%d", &n, &k) != EOF){
if(!n && !k) break;
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
tot = 1;
for(int i = 1; i < n; i++){
int from, to, val; scanf("%d%d%d", &from, &to, &val);
addedge(from, to, val);
addedge(to, from, val);
}
hvy = 0; S = n; ans = 0;
get_hvy(1, -1);
get_ans(hvy, -1);
printf("%d\n", ans);
}
return 0;
}
對於洛谷OJ P3806 【模板】點分治1
這道題同樣需要對重心分解,但是對於當前重心的所有子樹,,當到子樹時,到的距離全部用數組judge以下標的形式存儲下來,rem數組則存儲了到的距離,這時遍歷m次需要判斷judge[query[i]-rem[j]]是否存在,如果存在,那麼路徑長度爲query[i]的長度存在的。具體看代碼:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1e4+10;
struct Edge{
int to, val, next;
}e[maxn<<1];
int query[maxn], ret[maxn];
int head[maxn], tot;
int siz[maxn], maxx[maxn], hvy, S;
int vis[maxn], rem[maxn], dist[maxn], judge[10000007];
int n, m, q[maxn];
void addedge(int from, int to, int val){
e[tot].to = to;
e[tot].val = val;
e[tot].next = head[from];
head[from] = tot++;
}
// find grativity
void get_hvy(int u, int fa){
siz[u] = 1, maxx[u] = 0;
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(v != fa && !vis[v]){
get_hvy(v, u);
siz[u] += siz[v];
maxx[u] = max(maxx[u], siz[v]);
}
}
maxx[u] = max(maxx[u], S-siz[u]);
if(!hvy || maxx[hvy] > maxx[u]) hvy = u;
}
void get_dist(int u, int fa){
rem[++rem[0]] = dist[u];
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(!vis[v] && v != fa){
dist[v] = dist[u] + e[i].val;
get_dist(v, u);
}
}
}
void get_ans(int u){
int p = 0;
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(vis[v]) continue;
rem[0] = 0, dist[v] = e[i].val;
get_dist(v, u); // 處理u的每個子樹
for(int j = rem[0]; j; --j)
for(int k = 1; k <= m; k++)
if(query[k] >= rem[j]) ret[k] |= judge[query[k]-rem[j]];
for(int j = rem[0]; j; --j)
q[++p] = rem[j], judge[rem[j]] = 1;
}
for(int i = 1; i <= p; i++)
judge[q[i]] = 0;
}
void cal(int u){
vis[u] = judge[0] = 1; get_ans(u);
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to;
if(vis[v]) continue;
S = siz[v], hvy = 0;
get_hvy(v, u); cal(hvy);
}
}
int main(){
scanf("%d%d", &n, &m);
memset(head, -1, sizeof(head)); tot = 0;
memset(vis, 0, sizeof(vis));
for(int i = 1; i < n; i++){
int from, to, val; scanf("%d%d%d", &from, &to, &val);
addedge(from, to, val);
addedge(to, from, val);
}
for(int i = 1; i <= m; i++) scanf("%d", &query[i]);
S = n; hvy = 0;
get_hvy(1, -1); cal(hvy);
for(int i = 1; i <= m; i++)
if(ret[i]) puts("AYE");
else puts("NAY");
return 0;
}