LOJ #6496. 「雅禮集訓 2018 Day1」仙人掌(分治FFT,仙人掌DP)

題目
沒什麼好說的。
考慮樹,分治FFT即可。
上環,環上每個除了父親的點求出方案數後枚舉與父親相連的兩條邊的方向後O(len)dpO(len)dp求方案數即可。

AC Code\mathcal AC \ Code

#include<bits/stdc++.h>
#define maxn 300005
#define mod 998244353
using namespace std;

char cb[1<<16],*cs=cb,*ct=cb;
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<16,stdin),cs==ct)?0:*cs++)
void read(int &res){
	char ch;
	for(;!isdigit(ch=getc()););
	for(res=ch-'0';isdigit(ch=getc());res=res*10+ch-'0');
}

int n,m,a[maxn],in[maxn];
int info[maxn],Prev[maxn<<1],to[maxn<<1],cnt_e=1;
void Node(int u,int v){ Prev[++cnt_e]=info[u],info[u]=cnt_e,to[cnt_e]=v; }
int fir[maxn<<1],nxt[maxn<<2],tar[maxn<<2],cnte,tot;
void add(int u,int v){ nxt[++cnte] = fir[u] , fir[u] = cnte , tar[cnte] = v; }

// NTT
int Wl,Wl2,w[maxn],lg[maxn],inv[maxn];
int Pow(int b,int k){ int r=1;for(;k;k>>=1,b=1ll*b*b%mod) if(k&1) r=1ll*r*b%mod;return r; }
void init(int n){
	for(Wl=w[0]=inv[0]=inv[1]=1;n>=Wl<<1;Wl<<=1);w[1]=Pow(3,(mod-1)/(Wl2=Wl<<1));
	for(int i=2;i<=Wl2;i++) w[i]=1ll*w[i-1]*w[1]%mod,lg[i]=lg[i>>1]+1,inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
void NTT(int *A,int n,int tp){
	static int r[maxn];
	for(int i=1;i<n;i++) i<(r[i]=r[i>>1]>>1|(i&1)<<lg[n]-1) && (swap(A[i],A[r[i]]),0);
	for(int L=1,B=Wl;L<n;L<<=1,B>>=1)for(int s=0;s<n;s+=L<<1) for(int k=s,x=0,t;k<s+L;k++,x+=B)
		t=1ll*w[tp^1?Wl2-x:x]*A[k+L]%mod,A[k+L]=(A[k]-t)%mod,A[k]=(A[k]+t)%mod;
	if(tp^1) for(int i=0;i<n;i++) A[i]=1ll*A[i]*inv[n]%mod; 
}
void MUL(int *A,int *B,int *C,int n,int m){
	static int st[2][maxn];
	int L = 1 << lg[n+m] + 1;
	for(int i=0;i<L;i++)
		st[0][i] = i <= n ? A[i] : 0,
		st[1][i] = i <= m ? B[i] : 0;
	NTT(st[0],L,1),NTT(st[1],L,1);
	for(int i=0;i<L;i++)
		C[i] = 1ll * st[0][i] * st[1][i] % mod;
	NTT(C,L,-1);
}

int dfn[maxn],low[maxn],tim;
void dfs0(int u,int fe){
	dfn[u] = low[u] = ++tim;
	static int sta[maxn]={};
	sta[++sta[0]] = u;
	for(int i=info[u],v;i;i=Prev[i]) if(!dfn[v=to[i]]){
		dfs0(v,i);
		low[u] = min(low[u] , low[v]);
		if(low[v] >= dfn[u]){
			if(low[v] > dfn[u]){
				add(v,u),add(u,v);
				for(int t=-1;t!=v;)
					t=sta[sta[0]--];
			}
			else{
				++tot;
				for(int t=-1;t!=v;){
					t=sta[sta[0]--];
					add(tot,t),add(t,tot);
				}
				add(tot,u),add(u,tot);
			}
		}
	}
	else if(fe ^ i ^ 1) low[u] = min(low[u] , dfn[v]);
}

int ar[maxn][3],f[maxn][3],cnt;
int *g[maxn<<2],sz[maxn<<2];

#define lc u<<1
#define rc lc|1
void Sol1(int u,int l,int r){
	if(u>1)g[u] = new int [1 << lg[r-l+1] + 2];
	if(l==r) return (void)(g[u][0] = ar[l][0] , g[u][1] = ar[l][1] , g[u][2] = ar[l][2] , sz[u] = ar[l][2] ? 2 : 1);
	int m=l+r>>1;
	Sol1(lc,l,m),Sol1(rc,m+1,r);
	sz[u] = sz[lc] + sz[rc];
	MUL(g[lc],g[rc],g[u],sz[lc],sz[rc]);
}

void dfs1(int u,int ff){ 
	for(int i=fir[u],v;i;i=nxt[i]) if((v=tar[i]) ^ ff)
		dfs1(v,u);
	cnt = 0;
	if(u <= n){
		for(int i=fir[u],v;i;i=nxt[i]) if((v=tar[i])^ff){
			++cnt;
			if(v <= n){
				ar[cnt][0] = f[v][1];
				ar[cnt][1] = f[v][0];
				ar[cnt][2] = 0;
			}
			else{
				ar[cnt][0] = f[v][2],
				ar[cnt][1] = f[v][1],
				ar[cnt][2] = f[v][0];
			}
		}
		g[1]=new int[1<<lg[max(a[u],cnt)]+2];
		for(int i=0;i<=a[u];i++) g[1][i] = 0;
		if(cnt)
			Sol1(1,1,cnt);
		else
			g[1][0] = 1;
		for(int i=1;i<=a[u];i++)
			g[1][i] = (g[1][i] + g[1][i-1]) % mod;
		if(a[u]>1) f[u][2] = g[1][a[u]-2];
		if(a[u]) f[u][1] = g[1][a[u]-1];
		f[u][0] = g[1][a[u]];
	}
	else{
		for(int i=fir[u],v;i;i=nxt[i]) if((v=tar[i])^ff){
			++cnt;
			ar[cnt][2] = f[v][2],
			ar[cnt][1] = f[v][1],
			ar[cnt][0] = f[v][0];
		}
		static int dp[2][2][2];
		memset(dp,0,sizeof dp);
		int nw = 1 , pr = 0;
		dp[nw][0][0] = dp[nw][1][1] = 1;
		for(int i=1;i<=cnt;i++){
			swap(nw,pr);
			for(int j=0;j<2;j++)
				for(int k=0;k<2;k++) if(dp[pr][j][k]){
					dp[nw][j][0] = (dp[nw][j][0] + dp[pr][j][k] * 1ll * ar[i][1+k]) % mod;
					dp[nw][j][1] = (dp[nw][j][1] + dp[pr][j][k] * 1ll * ar[i][k]) % mod;
					dp[pr][j][k] = 0;
				}
		}
		f[u][2] = dp[nw][1][0];
		f[u][1] = (dp[nw][0][0] + dp[nw][1][1]) % mod;
		f[u][0] = dp[nw][0][1];
	}
}

int main(){
	read(n),read(m);tot = n;
	init(2*n);
	for(int i=1,u,v;i<=m;i++){
		read(u),read(v);in[u]++,in[v]++;
		Node(u,v),Node(v,u);
	}
	for(int i=1;i<=n;i++) read(a[i]),a[i]=min(a[i],in[i]);
	dfs0(1,0);
	dfs1(1,0);
	printf("%d\n",(f[1][0]+mod)%mod);
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章