【題解】LuoGu5299:[PKUWC2018]Slay the Spire

原題傳送門
期望 (計數)DP
看起來的期望,他就是讓你輸出總和

首先明確一個出牌策略(假設已經拿到了m張牌)
因爲所有強化牌都大於1,所以按照強化牌權值從大到小能用就用
最多用(k1)(k-1)張強化牌,不過強化牌不夠我也沒辦法
剩下的牌我出攻擊牌,這樣的策略是最優的

需要找出一個O(n2)O(n^2)的做法

如果枚舉最終mm張牌中有ii張強化牌
並令fi,jf_{i,j}表示前ii張強化牌選了jj張,第ii張牌一定選的總強化乘積和
gi,jg_{i,j}表示前ii張攻擊牌選了jj張,第ii張牌一定選的總攻擊和
根據這三個玩意可以推出答案,分類討論

  • i<k1i<k-1,這時候所有強化牌應該全出,那麼攻擊牌出(ki)(k-i)張,如果枚舉前jj張攻擊牌裏出(ki)(k-i)張牌,就是gj,kig_{j,k-i},但是總共要選擇mm張牌,所以剩下的牌從njn-j張攻擊牌裏去選,方案數爲CnjmkC_{n-j}^{m-k}
    所以ans=j=0nfj,ij=0n(gj,kiCnjmk)ans=\sum_{j=0}^{n}f_{j,i}*\sum_{j=0}^{n}(g_{j,k-i}*C_{n-j}^{m-k})
  • i>=k1i>=k-1,這時候出(k1)(k-1)張強化牌,剩下的強化牌隨便選,出1張攻擊牌,剩下的攻擊牌也隨便選
    所以ans=j=0n(fj,k1Cnjik+1)j=0n(gj,1Cnjmi1)ans=\sum_{j=0}^{n}(f_{j,k-1}*C_{n-j}^{i-k+1})*\sum_{j=0}^{n}(g_{j,1}*C_{n-j}^{m-i-1})

問題變成怎麼求f,gf,g
Fi,jF_{i,j}表示前ii張強化牌選了jj張,Gi,jG_{i,j}同理
f,gf,g的定義相比只是少了是否強制選擇第ii張的限制條件,作爲輔助dp數組

轉移方程:

  • Fi,j=Fi1,j+Fi1,j1aiF_{i,j}=F_{i-1,j}+F_{i-1,j-1}*a_i
  • fi,j=Fi1,j1aif_{i,j}=F_{i-1,j-1}*a_i
  • Gi,j=Gi1,j+Gi1,j1+biCi1j1G_{i,j}=G_{i-1,j}+G_{i-1,j-1}+b_i*C_{i-1}^{j-1}
  • gi,j=Gi1,j1+biCi1j1g_{i,j}=G_{i-1,j-1}+b_i*C_{i-1}^{j-1}

再注意一下邊界就好了

Code:

#include <bits/stdc++.h>
#define maxn 3010
#define LL long long
using namespace std;
const LL qy = 998244353;
LL fac[maxn], inv[maxn], f[maxn][maxn], g[maxn][maxn], F[maxn][maxn], G[maxn][maxn], a[maxn], b[maxn], ans;
int n, m, k;

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

bool cmp(int x, int y){ return x > y; }
LL C(int n, int m){ return m > n ? 0 : fac[n] * inv[n - m] % qy * inv[m] % qy; }

LL pow(LL n, LL k){
	if (!k) return 1;
	LL sum = pow(n, k >> 1);
	sum = sum * sum % qy;
	if (k & 1) sum = sum * n % qy;
	return sum;
}

int main(){
	fac[0] = 1;
	for (int i = 1; i < maxn; ++i) fac[i] = fac[i - 1] * i % qy;
	inv[maxn - 1] = pow(fac[maxn - 1], qy - 2);
	for (int i = maxn - 2; i >= 0; --i) inv[i] = inv[i + 1] * (i + 1) % qy;
	int T = read();
	while (T--){
		n = read(), m = read(), k = read(), ans = 0;
		for (int i = 1; i <= n; ++i) a[i] = read();
		for (int i = 1; i <= n; ++i) b[i] = read();
		sort(a + 1, a + 1 + n, cmp);
		sort(b + 1, b + 1 + n, cmp);
		F[0][0] = f[0][0] = 1;
		for (int i = 1; i <= n; ++i)
			for (int j = 0; j <= i; ++j){
				F[i][j] = F[i - 1][j], G[i][j] = G[i - 1][j];
				if (j) (F[i][j] += (f[i][j] = F[i - 1][j - 1] * a[i] % qy)) %= qy,
					   (G[i][j] += (g[i][j] = (G[i - 1][j - 1] + b[i] * C(i - 1, j - 1) % qy) % qy)) %= qy;
			}
		for (int i = 0; i <= k - 1; ++i){
			LL s1 = 0, s2 = 0;
			for (int j = i; j <= n; ++j) (s1 += f[j][i]) %= qy;
			for (int j = k - i; j <= n; ++j) (s2 += g[j][k - i] * C(n - j, m - k) % qy) %= qy;
			(ans += s1 * s2 % qy) %= qy;
		}
		for (int i = k; i <= min(m, n); ++i){
			LL s1 = 0, s2 = 0;
			for (int j = k - 1; j <= n; ++j) (s1 += f[j][k - 1] * C(n - j, i - k + 1) % qy) %= qy;
			for (int j = 1; j <= n; ++j) (s2 += g[j][1] * C(n - j, m - i - 1) % qy) %= qy;
			(ans += s1 * s2 % qy) %= qy;
		}
		printf("%lld\n", ans);
	}
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章