Description
傳送門
T<=5,1<=n,m<=2000,1<=li,ri,hi<=1e9
Solution
- 首先很容易想到的是求所有節點都被佔領的概率,也就是方案數(最後再除以總數)。
- 考場上的時候想到了狀態f[x][i]表示x的子樹全部被佔領,從x的子樹上到x節點上的最大能力的殭屍是i。
- 但是難處理的是殭屍會從一個子樹上去再下來到某個節點,單單是考慮x的子樹中的殭屍是不行的,要考慮到所有的殭屍。
- 所以就有了一種很巧妙的方法,將f[x][i]變爲在整棵樹上的殭屍i到了x。有可能這個i不在x的子樹中,但是它可以從父親走過來,所以我們就假設它到了x,並且預先把它的貢獻算出來(這是我覺得這題設的狀態最巧妙並且最難懂的地方)
- 既然已經理解了狀態的話,方程就不難推了。
- 考慮x是當前點,y是它的兒子。假設y->x的邊被k這個殭屍通過的方案爲wk,通不過的方案爲vk。
- f′[x][k]+=f[x][k]∗f[y][k]∗wk,表示x或y如果有k的話,那這個k也通過(x,y)走到另一邊,不用管是從下往上還是從上往下,爲了讓k走過去,就要有wk。
- f′[x][k]+=f[x][k]∗∑t>kf[y][t]∗vt,表示當t比k要強的時候,t不能跳到k。注意t一定在y的子樹內,k一定不在y的子樹內。
- f′[x][k]+=f[x][k]∗vk∗∑t<kf[y][t],即t比k弱時,k不能跳到t。t,k範圍同上。
- f[x]的初值:假設x點起始的殭屍最強的是p,那麼f[x][i]=(i>=p),因爲不管殭屍怎麼走,起始的時候的最大殭屍就一定不小於p。
- 用前綴(後綴)和優化。
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define maxn 2005
#define ll long long
#define mo 998244353
using namespace std;
int T,n,m,i,j,k,x,y,u,v,tot;
int em,e[maxn*2],nx[maxn*2],ls[maxn],L[maxn*2],R[maxn*2];
int hv[maxn][maxn];
struct zom{int s,h;} z[maxn];
int cmp(zom a,zom b){return a.h<b.h;}
ll sum,ans,f[maxn][maxn],g[maxn],s[maxn];
ll ksm(ll x,ll y){
ll s=1;
for(;y;y/=2,x=x*x%mo) if (y&1)
s=s*x%mo;
return s;
}
ll Get(int i,int k,int t){
if (t==1) return min(max(0,z[k].h-L[i]),R[i]-L[i]+1);
return min(max(0,R[i]-z[k].h+1),R[i]-L[i]+1);
}
void dg(int x,int p){
int i,j,y; ll s;
for(i=ls[x];i;i=nx[i]) if (e[i]!=p)
dg(e[i],x);
int tp=1;
for(i=m;i>=1;i--) {
f[x][i]=tp;
if (hv[x][i]) tp=0;
}
for(i=ls[x];i;i=nx[i]) if (e[i]!=p){
y=e[i];
for(j=1;j<=m;j++) hv[x][j]|=hv[y][j];
memcpy(g,f[x],sizeof(g));
memset(f[x],0,sizeof(f[x]));
for(j=1;j<=m;j++) f[x][j]+=g[j]*f[y][j]%mo*Get(i,j,1)%mo;
for(s=0,j=m;j>=1;j--){
if (!hv[y][j]) f[x][j]+=s*g[j]%mo;
if (hv[y][j]) (s+=f[y][j]*Get(i,j,0)%mo)%=mo;
}
for(s=0,j=1;j<=m;j++) {
if (!hv[y][j]) f[x][j]+=s*g[j]%mo*Get(i,j,0)%mo;
if (hv[y][j]) (s+=f[y][j])%=mo;
}
}
for(i=1;i<=m;i++) f[x][i]%=mo;
}
int main(){
scanf("%d",&T);
while (T--){
scanf("%d%d",&n,&m);
em=0,memset(ls,0,sizeof(ls));
sum=1;
for(i=1;i<n;i++){
scanf("%d%d%d%d",&x,&y,&u,&v);
em++; e[em]=y; nx[em]=ls[x]; ls[x]=em; L[em]=u,R[em]=v;
em++; e[em]=x; nx[em]=ls[y]; ls[y]=em; L[em]=u,R[em]=v;
sum=sum*(v-u+1)%mo;
}
for(i=1;i<=m;i++) scanf("%d%d",&z[i].s,&z[i].h);
sort(z+1,z+1+m,cmp);
memset(hv,0,sizeof(hv));
for(i=1;i<=m;i++) hv[z[i].s][i]=1;
dg(1,0);
for(ans=0,i=1;i<=m;i++) ans+=f[1][i];
printf("%lld\n",((sum-ans%mo)%mo*ksm(sum,mo-2)%mo+mo)%mo);
}
}