轉化下就是給出一棵樹求最少要斷掉多少條邊才能分隔出一個有p個節點的子樹
這道樹形dp就相對難一點了,我寫的中還用到了數組的滾動,整理出來方程即
f[root][j]=min{f2[root][j]+1,f2[root][k]+f[node][j-k]}
前面那部分,即把root->node這條邊分隔開來
而後面那部分,即是前面保留k個結點,而以node爲根的子樹保留j-k個結點
這裏循環完後,f[root][j]含義是以root爲根的樹中分隔出j個結點的以root爲根的子樹所需最少斷邊數
所以答案並非直接的f[troot][p],而是min{f[troot][p],f[tnode][p]+1} (若該子樹根結點不是總的樹的根節點,還需多分離
一個邊)
代碼如下:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Iterator;
import java.util.StringTokenizer;
import java.util.Vector;
class Reader{
static BufferedReader reader;
static StringTokenizer tokenizer;
static void init(InputStream input) {
reader=new BufferedReader(new InputStreamReader(input));
tokenizer=new StringTokenizer("");
}
static String next() throws IOException{
while (!tokenizer.hasMoreTokens()) {
tokenizer=new StringTokenizer(reader.readLine());
}
return tokenizer.nextToken();
}
static int nextInt() throws IOException{
return Integer.parseInt(next());
}
}
public class Main {
/**
* @param args
*/
static int n,p,u,v,root,ans;
static Vector<Integer> vector[];
static boolean nroot[];
static int depth[],size[];
static int f[][],f2[][];
private static int getmin(int a,int b) {
return a<b?a:b;
}
private static void dfs(int u,int dep) {
Iterator<Integer> iter=vector[u].iterator();
if (dep==1) depth[u]=0;
else depth[u]=1;
int node;
int psize=1;
size[u]=1;
f[u][1]=0;
while (iter.hasNext()) {
node=iter.next();
psize=size[u];
dfs(node,dep+1);
size[u]+=size[node];
for (int i=1;i<=psize;i++)
f2[u][i]=f[u][i];
for (int i=1;i<=size[u];i++) {
f[u][i]=n+1;
if (i<=psize) f[u][i]=f2[u][i]+1;
for (int k=1;k<=i-1;k++)
if ((k<=psize)&(i-k<=size[node]))
f[u][i]=getmin(f[u][i],f2[u][k]+f[node][i-k]);
//System.out.println(u+" "+i+" "+node+" "+f[u][i]);
}
}
}
public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
Reader.init(System.in);
n=Reader.nextInt();
p=Reader.nextInt();
vector=new Vector[n+1];
size=new int[n+1];
nroot=new boolean[n+1];
depth=new int[n+1];
for (int i=1;i<=n;i++)
vector[i]=new Vector<Integer>();
for (int i=1;i<=n-1;i++) {
u=Reader.nextInt();
v=Reader.nextInt();
vector[u].add(v);
nroot[v]=true;
}
for (int i=1;i<=n;i++)
if (!nroot[i]) {
root=i;
break;
}
f=new int[n+1][n+1];
f2=new int[n+1][n+1];
dfs(root,1);
ans=n+1;
for (int i=1;i<=n;i++)
if (size[i]>=p)
ans=getmin(ans,f[i][p]+depth[i]);
System.out.println(ans);
}
}