題目地址:
http://www.lydsy.com/JudgeOnline/problem.php?id=4297
= =
題意:
給定一棵有n個點,m個葉子節點的樹,其中m個葉子節點分別爲1到m號點,每個葉子節點有一個權值r[i]。你需要給剩下n-m個點各指定一個權值,使得樹上相鄰兩個點的權值差的絕對值之和最小。
思路:
QAQ,看了Claris的代碼,又自己想了想,但還是有點迷迷糊糊。後來jxt看了這題,說了他的思路,自己才理解這題。然後下面說的是jxt的思路
首先題目給了m個權值確定的葉子節點,那麼答案的確定可以通過葉子節點,自底向上地完成。討論一個情況:當前結點爲y,x是y的子節點,x有
上述情況可以得出兩個結論:
- 當
|kbig−ksmall| 最小時,x的子樹貢獻的答案最小,爲最優,此時val[x] 的取值明顯是一個範圍,在這個範圍裏,|kbig−ksmall| 達到了最小值。 - 當
val[x] 大小加減1時,|val[y]−val[x]| 的取值變化始終是1,但當val[x] 沒有處於使|kbig−ksmall| 最小的範圍時,顯然:val[x] 大小加減1對|kbig−ksmall| 變化的影響始終大於等於1,則val[x] 與x子樹取值對答案的貢獻大於val[x] 與val[y] 對答案的貢獻。如果x及其子樹沒有達到最優,那麼當x達到最優時的答案一定比當前答案優秀。(即一個樹達到最優的條件是其所有子樹達到最優)
綜上,爲了解題,我們只需要從低往上,確定當前節點使其子樹節點最優的取值範圍,然後用這個取值範圍遞推出其父節點的取值範圍,就可以求出答案。不需要重新建圖。
代碼:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstdlib>
using namespace std;
#define PB push_back
#define MS(x, y) memset(x, y, sizeof(x))
typedef pair<int, int> P;
typedef long long LL;
const int MAXN = 5e5 + 5;
const LL INF = 1000000000000000000LL;
int n, m;
LL ans;
int l[MAXN], r[MAXN];
int fa[MAXN];
P a[MAXN << 1];
vector<int> edges[MAXN];
void dfs(int u, int fa) {
// cout << "dfs ing ..." << endl;
if (edges[u].size() == 1) return ;
for (int i = edges[u].size() - 1; i >= 0; --i) {
if (edges[u][i] == fa) continue;
dfs(edges[u][i], u);
}
int cnt = 0, v;
LL mn = INF, sum = 0, now;
int fut = 0, pst = 0;
// 確定u使其子樹達到最優的取值範圍
LL fut_sum = 0, pst_sum = 0;
for (int i = edges[u].size() - 1; i >= 0; --i) {
v = edges[u][i];
if (v == fa) continue;
a[cnt++] = P(l[v], 0);
a[cnt++] = P(r[v], 1);
++fut;
fut_sum += l[v];
}
sort(a, a + cnt);
for (int i = 0; i < cnt; ++i) {
if (a[i].second) {
++pst;
pst_sum += a[i].first;
} else {
--fut;
fut_sum -= a[i].first;
}
now = fut_sum - fut * a[i].first + pst * a[i].first - pst_sum;
if (now < mn) {
mn = now;
l[u] = a[i].first;
}
if (now == mn) r[u] = a[i].first;
}
ans += mn;
}
int main() {
while (~scanf("%d%d", &n, &m)) {
int u, v, lim;
ans = 0;
head = tail = 0;
MS(fa, 0);
MS(deg, 0);
MS(used, false);
for (int i = 1; i <= n; ++i) edges[i].clear();
for (int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
edges[u].PB(v);
edges[v].PB(u);
}
for (int i = 1; i <= m; ++i) {
scanf("%d", l + i);
r[i] = l[i];
}
if (n == m) {
for (u = 1; u <= n; ++u) {
for (int i = edges[u].size() - 1; i >= 0; --i) {
v = edges[u][i];
ans += abs(l[u] - l[v]);
}
}
printf("%I64d\n", ans / 2);
continue;
}
dfs(n, 0);
printf("%I64d\n", ans);
}
}