Appleman has a tree with n vertices. Some of the vertices (at least one) are colored black and other vertices are colored white.
Consider a set consisting of k (0 ≤ k < n) edges of Appleman’s tree. If Appleman deletes these edges from the tree, then it will split into (k + 1) parts. Note, that each part will be a tree with colored vertices.
Now Appleman wonders, what is the number of sets splitting the tree in such a way that each resulting part will have exactly one black vertex? Find this number modulo 1000000007 (109 + 7).
Input
The first line contains an integer n (2 ≤ n ≤ 105) — the number of tree vertices.
The second line contains the description of the tree: n - 1 integers p 0, p 1, …, p n - 2 (0 ≤ p i ≤ i). Where p i means that there is an edge connecting vertex (i + 1) of the tree and vertex p i. Consider tree vertices are numbered from 0 to n - 1.
The third line contains the description of the colors of the vertices: n integers x 0, x 1, …, x n - 1 ( x i is either 0 or 1). If x i is equal to 1, vertex i is colored black. Otherwise, vertex i is colored white.
Output
Output a single integer — the number of ways to split the tree modulo 1000000007 (109 + 7).
Examples
Input
3
0 0
0 1 1
Output
2
Input
6
0 1 1 0 4
1 1 0 0 1 0
Output
1
Input
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
Output
27
題意:
一棵樹,每個點可以是黑色也可以是白色。
要求減掉個邊分成塊使得每個塊只有一個黑點
思路:
定義代表爲根節點子樹有沒有黑點,有一個黑點的分法。
則對於,
如果當前子樹有黑點,要去掉這個邊,也就是
如果當前子樹有黑點,可以保留這個邊,也就是
如果當前子樹沒有黑點,那麼肯定要保留這個邊,否則子樹連通塊沒有黑點了,也就是
對於,
如果當前子樹沒有黑點,則保留這個邊,也就是
如果當前子樹有黑點,則減掉這個邊,也就是
然後轉移就好了。
對於這種樹dp計數的問題總不是很感冒,感覺一般就是,乘法原理加上之前子樹劃分的狀態和當前子樹的狀態,最後可能還要去重。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <set>
#include <queue>
#include <map>
#include <string>
#include <iostream>
#include <cmath>
using namespace std;
typedef long long ll;
typedef long long ll;
const int maxn = 1e5 + 7;
const int mod = 1e9 + 7;
vector<int>G[maxn];
int a[maxn];
ll dp[maxn][2];
void dfs(int u,int fa) {
if(a[u] == 1) {
dp[u][1] = 1;
} else {
dp[u][0] = 1;
}
for(int i = 0;i < G[u].size();i++) {
int v = G[u][i];
if(v == fa) continue;
dfs(v,u);
dp[u][1] = (dp[u][1] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod + dp[u][1] * dp[v][1] % mod) % mod;
dp[u][0] = (dp[u][0] * dp[v][0] % mod + dp[u][0] * dp[v][1] % mod) % mod;
}
}
int main() {
int n;scanf("%d",&n);
for(int i = 2;i <= n;i++) {
int x;scanf("%d",&x);x++;
G[x].push_back(i);
G[i].push_back(x);
}
for(int i = 1;i <= n;i++) scanf("%d",&a[i]);
dfs(1,-1);
printf("%lld\n",dp[1][1]);
return 0;
}