大概題意是: 給你一顆無根樹,每一個結點有點權, 有多少顆子樹的結點乘積不超過m?子樹的定義是樹上的連通塊。
首先我們考慮另一個問題,假設給你一顆有根樹,所以包含根的子樹有多少種滿足乘積不超過m?
考慮樹形dp的做法,定義dp[i][j]是在i被選取後 這顆子樹中乘積爲j的子樹方案數,這樣每次將兩顆子樹合併的複雜度是m*m的。但實際上子樹大小限制了狀態數不會那麼多,所以每次計算一個點的貢獻的複雜度是 o(m)的。
這裏題解有一個很巧妙的求法,因爲如果一個點被選取,那麼他的父親一定被選取。如果一個點不被選取,那麼它子樹中的所有點也不會被選取。所以當我們從fa->u時,dp[fa]已經包含了u不被選取的情況(因爲u還沒有被搜索,沒有任何節點統計了貢獻)。那麼我們只需要統計u被選取的情況,因爲u被選取了,那麼fa一定也被選取了,所以u可以繼承fa的信息,然後繼續遞歸即可。
但是這樣dfs一次的複雜度是n*m的,還是有點吃不消。注意題目中要求的是乘積。我們可以考慮剩餘的子樹大小,比如當前有兩顆乘積爲m,m-1的子樹,那麼他們都只能在添加大小爲1的子樹,所以狀態可以合併。其實就是一個整除分塊。所以dp[i][j]表示i點在選取後還能添加大小爲j的子樹的方案數,這樣狀態數只有。複雜度是o(n*).
包含根的子樹統計後,只要太統計一遍不包含根情況,只要把根標記一下,遞歸根的子樹,這裏就可以點分治了。
#include <bits/stdc++.h>
using namespace std;
#define N 2025
#define ll long long
#define mod 1000000007
#define go(i,a,b) for(int i=(a);i<=(b);i++)
#define dep(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
#define inf 0x3f3f3f3f
#define ld long double
#define pii pair<int,int>
#define vi vector<int>
#define add(a,b) (a+=(b)%mod)%=mod
#define lowb(x,c,len) lower_bound(c+1,c+len+1,x)-c
#define uppb(x,c,len) upper_bound(c+1,c+len+1,x)-c
#define ls i*2+1
#define rs i*2+2
#define mid (l+r)/2
#define lson l,mid,ls
#define rson mid+1,r,rs
int n,m,sz,cnt,tot,root,ans,las;
int h[N],sum[N],mson[N],vis[N],a[N],w[N],f[N*N],dp[N][N];
struct no{
int to,n;
};no eg[N*2];
void link(int u,int to){
eg[++tot]={to,h[u]};h[u]=tot;
eg[++tot]={u,h[to]};h[to]=tot;
}
void getroot(int u,int fa){
sum[u]=1;mson[u]=0;
for(int i=h[u];i;i=eg[i].n){
int to=eg[i].to;
if(to==fa||vis[to])continue;
getroot(to,u);
sum[u]+=sum[to];
mson[u]=max(mson[u],sum[to]);
}
mson[u]=max(mson[u],sz-sum[u]);
if(mson[u]<mson[root])root=u;
}
void dfs(int u,int fa){
go(i,1,cnt)dp[u][i]=0;
go(i,1,cnt)if(w[i]>=a[u])add(dp[u][f[w[i]/a[u]]],dp[fa][i]);
for(int i=h[u];i;i=eg[i].n){
int to=eg[i].to;
if(to==fa||vis[to])continue;
dfs(to,u);
go(i,1,cnt)add(dp[u][i],dp[to][i]);
}
}
void divide(int u){
//cout<<u<<endl;
dp[0][cnt]=1;dfs(u,0);vis[u]=1;
go(i,1,cnt)add(ans,dp[u][i]);
for(int i=h[u];i;i=eg[i].n){
int to=eg[i].to;
if(vis[to])continue;
mson[root=0]=sz=sum[to];
getroot(to,0);divide(root);
}
}
void solve(){
mson[root=0]=sz=n;
getroot(1,-1);divide(root);
printf("%d\n",ans);
ans=tot=cnt=las=0;
go(i,1,n)h[i]=vis[i]=0;
memset(dp[0],0,sizeof dp[0]);
}
int main()
{
int T,u,to;cin>>T;while(T--){
scanf("%d%d",&n,&m);
dep(i,m,1){
int x=m/i;
w[f[x]=(x!=las?++cnt:cnt)]=x;
las=x;
}
go(i,1,n)scanf("%d",&a[i]);
go(i,2,n)scanf("%d%d",&u,&to),link(u,to);
solve();
}
return 0;
}