2019hdu多校六 Ridiculous Netizens(點分治)

大概題意是: 給你一顆無根樹,每一個結點有點權, 有多少顆子樹的結點乘積不超過m?子樹的定義是樹上的連通塊。

首先我們考慮另一個問題,假設給你一顆有根樹,所以包含根的子樹有多少種滿足乘積不超過m?

考慮樹形dp的做法,定義dp[i][j]是在i被選取後 這顆子樹中乘積爲j的子樹方案數,這樣每次將兩顆子樹合併的複雜度是m*m的。但實際上子樹大小限制了狀態數不會那麼多,所以每次計算一個點的貢獻的複雜度是 o(m)的。

這裏題解有一個很巧妙的求法,因爲如果一個點被選取,那麼他的父親一定被選取。如果一個點不被選取,那麼它子樹中的所有點也不會被選取。所以當我們從fa->u時,dp[fa]已經包含了u不被選取的情況(因爲u還沒有被搜索,沒有任何節點統計了貢獻)。那麼我們只需要統計u被選取的情況,因爲u被選取了,那麼fa一定也被選取了,所以u可以繼承fa的信息,然後繼續遞歸即可。

但是這樣dfs一次的複雜度是n*m的,還是有點吃不消。注意題目中要求的是乘積。我們可以考慮剩餘的子樹大小,比如當前有兩顆乘積爲m,m-1的子樹,那麼他們都只能在添加大小爲1的子樹,所以狀態可以合併。其實就是一個整除分塊。所以dp[i][j]表示i點在選取後還能添加大小爲j的子樹的方案數,這樣狀態數只有^{\sqrt{m}}。複雜度是o(n*^{\sqrt{m}}).

包含根的子樹統計後,只要太統計一遍不包含根情況,只要把根標記一下,遞歸根的子樹,這裏就可以點分治了。

#include <bits/stdc++.h>

using namespace std;

#define N 2025
#define ll long long
#define mod 1000000007
#define go(i,a,b) for(int i=(a);i<=(b);i++)
#define dep(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
#define inf 0x3f3f3f3f
#define ld long double
#define pii pair<int,int>
#define vi vector<int>
#define add(a,b) (a+=(b)%mod)%=mod
#define lowb(x,c,len) lower_bound(c+1,c+len+1,x)-c
#define uppb(x,c,len) upper_bound(c+1,c+len+1,x)-c
#define ls i*2+1
#define rs i*2+2
#define mid (l+r)/2
#define lson l,mid,ls
#define rson mid+1,r,rs
int n,m,sz,cnt,tot,root,ans,las;
int h[N],sum[N],mson[N],vis[N],a[N],w[N],f[N*N],dp[N][N];
struct no{
    int to,n;
};no eg[N*2];
void link(int u,int to){
    eg[++tot]={to,h[u]};h[u]=tot;
    eg[++tot]={u,h[to]};h[to]=tot;
}
void getroot(int u,int fa){
    sum[u]=1;mson[u]=0;
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(to==fa||vis[to])continue;
        getroot(to,u);
        sum[u]+=sum[to];
        mson[u]=max(mson[u],sum[to]);
    }
    mson[u]=max(mson[u],sz-sum[u]);
    if(mson[u]<mson[root])root=u;
}
void dfs(int u,int fa){
    go(i,1,cnt)dp[u][i]=0;
    go(i,1,cnt)if(w[i]>=a[u])add(dp[u][f[w[i]/a[u]]],dp[fa][i]);
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(to==fa||vis[to])continue;
        dfs(to,u);
        go(i,1,cnt)add(dp[u][i],dp[to][i]);
    }
}
void divide(int u){
    //cout<<u<<endl;
     dp[0][cnt]=1;dfs(u,0);vis[u]=1;
    go(i,1,cnt)add(ans,dp[u][i]);
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(vis[to])continue;
        mson[root=0]=sz=sum[to];
        getroot(to,0);divide(root);
    }
}
void solve(){
    mson[root=0]=sz=n;
    getroot(1,-1);divide(root);
    printf("%d\n",ans);
    ans=tot=cnt=las=0;
    go(i,1,n)h[i]=vis[i]=0;
    memset(dp[0],0,sizeof dp[0]);
}
int main()
{
    int T,u,to;cin>>T;while(T--){
        scanf("%d%d",&n,&m);
        dep(i,m,1){
            int x=m/i;
            w[f[x]=(x!=las?++cnt:cnt)]=x;
            las=x;
        }
        go(i,1,n)scanf("%d",&a[i]);
        go(i,2,n)scanf("%d%d",&u,&to),link(u,to);
        solve();
    }
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章