主要利用了LCA 的性質——一個區間的LCA等於區間內相鄰點的深度最低的LCA。
所以這個問題就轉化成選取k個相鄰點的深度 或 單點的深度。
#include <bits/stdc++.h>
using namespace std;
#define N 300005
#define go(i,a,b) for(int i=(a);i<=(b);i++)
#define inf 0x3f3f3f3f
struct no{
int to,n;
};no eg[N*2];
int tot,cnt,h[N],F[N*2],rmq[N*2],dep[N],P[N],a[N],n,k,u,to;
void add(int u,int to){
eg[++tot]={to,h[u]};h[u]=tot;
eg[++tot]={u,h[to]};h[to]=tot;
}
struct ST
{
int mm[2*N];
int dp[2*N][20];
void init(int n)
{
mm[0]=-1;
go(i,1,n)
{
dp[i][0]=i;
mm[i]=((i&(i-1))==0)?mm[i-1]+1:mm[i-1];
}
go(j,1,mm[n])
for(int i=1; i+(1<<j)-1<=n; i++)
dp[i][j]=rmq[dp[i][j-1]]<rmq[dp[i+(1<<(j-1))][j-1]]?dp[i][j-1]:dp[i+(1<<(j-1))][j-1];
}
int query(int a,int b)
{
if(a>b)swap(a,b);
int k=mm[b-a+1];
return rmq[dp[a][k]]<=rmq[dp[b-(1<<k)+1][k]]?dp[a][k]:dp[b-(1<<k)+1][k];
}
} st;
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
F[++cnt]=u;
rmq[cnt]=dep[u];
P[u]=cnt;
for(int i=h[u];i;i=eg[i].n){
int to=eg[i].to;
if(to==fa)continue;
dfs(to,u);
F[++cnt]=u;
rmq[cnt]=dep[u];
}
}
void lca_st(int root,int n){
cnt=0;
dfs(root,0);
st.init(2*n-1);
}
int lca_find(int x,int y){return F[st.query(P[x],P[y])]; }
int main()
{
while(scanf("%d%d",&n,&k)!=EOF){
go(i,1,n)h[i]=0;tot=0;
go(i,1,n)scanf("%d",&a[i]);
go(i,2,n)scanf("%d%d",&u,&to),add(u,to);
lca_st(1,n);
int dp[n+10][k+10];
memset(dp,inf,sizeof dp);dp[0][0]=0;
go(i,1,n)go(j,0,k){
if(j>i)continue;
dp[i][j]=min(dp[i-1][j],dp[i][j]);
if(j)dp[i][j]=min(dp[i][j],dp[i-1][j-1]+dep[a[i]]);
if(i>2&&j)dp[i][j]=min(dp[i][j],dp[i-2][j-1]+dep[lca_find(a[i],a[i-1])]);
}
printf("%d\n",dp[n][k]);
}
return 0;
}