hdu 5468 容斥加樹形dp

容斥,num統計樹中含有因子i的結點有多少個,然後用總結點數減去就可以了,注意now-pre,因爲肯定要減,因爲當從一個子樹到另一個子樹的時候,也就是另一個子樹的num值還沒加進來的時候,num值統計的是全局的,所以要把子樹外面的值給減掉~

#include<iostream>
#include<algorithm>
#include<string>
#include<map>//int dx[4]={0,0,-1,1};int dy[4]={-1,1,0,0};
#include<set>//int gcd(int a,int b){return b?gcd(b,a%b):a;}
#include<vector>
#include<cmath>
#include<stack>
#include<string.h>
#include<stdlib.h>
#include<cstdio>
#define mod 1e9+7
#define ll long long
using namespace std;
int n,a,b;
vector<int> prime_element[100005];
int y[100005],ans[100005],num[100005];//num[i]表示子樹裏因子爲i的有多少個 

int head[100005],e;
struct Edge{
    int to,next;
}edge[100005*2];

void add(int u,int v)
{
   edge[e].to=v;
   edge[e].next=head[u];
   head[u]=e++;
}

int calc(int p,int q){  
    int res=0;
    for(int i=1;i<(1<<prime_element[p].size());i++){
        int w=1;
        int cnt=0;
        for(int j=0;j<prime_element[p].size();j++){
            if(i&(1<<j)){
                cnt++;  //這個因子由多少素因子構成 
                w*=prime_element[p][j];
            }
        }
        if(cnt%2)  //容斥 
            res+=num[w];
        else
            res-=num[w];
        num[w]+=q; //q=1時,表示要退出這顆子樹了,把根節點的因子算上 
    }
    return res;
}

int dfs(int p,int q){
    int pre=calc(y[p],0);
//  cout <<" p = "<<p <<" pre = "<<pre<<endl;
    int s=0;
    for(int i=head[p];~i;i=edge[i].next){
        int v=edge[i].to;
        if(v==q)
            continue;
        s+=dfs(v,p);       
    }
    int now=calc(y[p],1);
    ans[p]=s-(now-pre); //總結點-與根節點不互質節點數 
    if(y[p]==1)
        ans[p]++;  //是1的話它和本身也互質
    return s+1; 
}
int main(){
    for(int i=2;i<=100000;++i){
        if(!prime_element[i].empty())
            continue;
        for(int j=i;j<=100000;j+=i){
            prime_element[j].push_back(i); //存放j的所有素因子【模板】
        }
    }
    int cnt=0;
    while(~scanf("%d",&n)){
        e=0;
        memset(num,0,sizeof(num));
        memset(head,-1,sizeof(head));
        for(int i=0;i<n-1;++i){
            scanf("%d%d",&a,&b);
            add(a,b);
            add(b,a); 
        } 
        for(int i=1;i<=n;++i)
            scanf("%d",&y[i]);
        dfs(1,-1);
        printf("Case #%d:",++cnt);
        for(int i=1;i<=n;i++){
           printf(" %d",ans[i]);
        }
        printf("\n");
    }
} 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章