題意:給定一棵樹, 然後加一條邊, 有若干詢問, 問你每一個詢問(u,v), 加了這條邊後可以從u到v節省多少距離。
思路: 一共三種情況, 1, 原路
2,u - x - y - v
3,u - y - x - v
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100010;
struct node{
int to, w, next;
}edge[MAXN*2];
int tot,head[MAXN];
void init(){
tot = 0;memset(head, -1, sizeof(head));
}
void add_edge(int u, int v, int w){
edge[tot].to = v;
edge[tot].w = w;
edge[tot].next = head[u];
head[u] = tot++;
}
//LCA部分
int rmq[2*MAXN];//rmq數組,就是歐拉序列對應的深度序列
struct ST
{
int mm[2*MAXN];
int dp[2*MAXN][20];//最小值對應的下標
void init(int n)
{
mm[0] = -1;
for(int i = 1;i <= n;i++)
{
mm[i] = ((i&(i-1)) == 0)?mm[i-1]+1:mm[i-1];
dp[i][0] = i;
}
for(int j = 1; j <= mm[n];j++)
for(int i = 1; i + (1<<j) - 1 <= n; i++)
dp[i][j] = rmq[dp[i][j-1]] < rmq[dp[i+(1<<(j-1))][j-1]]?dp[i][j-1]:dp[i+(1<<(j-1))][j-1];
}
int query(int a,int b)//查詢[a,b]之間最小值的下標
{
if(a > b)swap(a,b);
int k = mm[b-a+1];
return rmq[dp[a][k]] <= rmq[dp[b-(1<<k)+1][k]]?dp[a][k]:dp[b-(1<<k)+1][k];
}
};
int F[MAXN*2];//歐拉序列,就是dfs遍歷的順序,長度爲2*n-1,下標從1開始
int P[MAXN];//P[i]表示點i在F中第一次出現的位置
int cnt;
int dis[MAXN];
ST st;
void dfs(int u,int pre,int dep, int d)
{
dis[u] = d;
F[++cnt] = u;
rmq[cnt] = dep;
P[u] = cnt;
for(int i = head[u];i != -1;i = edge[i].next)
{
int v = edge[i].to;
if(v == pre)continue;
dfs(v,u,dep+1, edge[i].w+d);
F[++cnt] = u;
rmq[cnt] = dep;
}
}
void LCA_init(int root,int node_num)//查詢LCA前的初始化
{
cnt = 0;
dfs(root,root,0, 0);
st.init(2*node_num-1);
}
int query_lca(int u,int v)//查詢u,v的lca編號
{
return F[st.query(P[u],P[v])];
}
int calc(int u, int v){
int LCA = query_lca(u, v);
return dis[u] + dis[v] - 2*dis[LCA];
}
int main(){
int T;
cin>>T;
int icase = 0;
while(T--){
init();
printf("Case #%d:\n", ++icase);
int n,q;
scanf("%d %d", &n, &q);
for(int i=1; i<n; i++){
int u,v,w;
scanf("%d %d %d", &u, &v, &w);
add_edge(u, v, w), add_edge(v, u, w);
}
LCA_init(1, n);
int u,v,w;
int lca = query_lca(u, v);
scanf("%d %d %d", &u, &v, &w);
while(q--){
int x,y;
scanf("%d %d", &x, &y);
int sum1 = calc(x, y);
int sum2 = calc(x, u) + w + calc(y, v);
int sum3 = calc(y, u) + w + calc(x, v);
if(sum2 >= sum1 && sum3>=sum1) printf("0\n");
else printf("%d\n", sum1 - min(sum2, sum3));
}
}
return 0;
}