[LOJ 2538]「PKUWC2018」Slay the Spire

题目链接

LOJ 2358

前言

既然要去PKUWC,当然要做做去年的题看看自己水平啦。

做法

首先要想到一个结论:如果有$\ge k-1$张翻倍牌的话,一定会先打掉最大的k-1张翻倍牌,然后再打一张点数最大的攻击牌,由于每张翻倍牌至少$*=2$,所以这个结论应该是显然的吧。
再观察一下ans后面乘的那个式子,实际上就是C(2n,m),相当于把求期望变成了求所有方案的和。(如果他不乘的话你就自己除掉就行了,没什么影响)
既然是要求总和,那么
getf(i,j)表示我从n张翻倍牌中选i张,这i张中前j大的乘积,再sigma起来。
getg(i,j)表示我从n张攻击牌中选i张,这i张中前j大的加和,再sigma起来。

我们考虑枚举取i张翻倍牌,分情况讨论,结合上面的结论就可以得到
下面所有代码中的Inc函数,表示把两个加起来对p取模。
C则表示组合数。

1
2
3
4
5
6
7
8
9
for (int i=0;i<=min(m,n);i++){
if (m-i>n) continue;
if (i<k){
ans=Inc(ans,1ll*getf(i,i)*getg(m-i,k-i)%P);
}
else{
ans=Inc(ans,1ll*getf(i,k-1)*getg(m-i,1)%P);
}
}

这里应该没问题吧。
那么我们接下来考虑如何getf和getg.
先说getf吧,先把a数组排序,使翻倍牌的大小递减。
预处理f数组,f[i][j]表示前i张牌取j张,所得成绩的sigma,应该是个简单dp吧

1
2
3
4
5
f[0][0]=1;
for (int i=1;i<=n;i++){
for (int j=1;j<=i;j++) f[i][j]=Inc(f[i-1][j],1ll*f[i-1][j-1]*a[i]%P);
f[i][0]=f[i-1][0];
}

然后我们枚举第j大的取的是第几张翻倍牌,前面的贡献可以从f中导出来,后面是个组合数
博主语言能力有限,还是代码比较好理解。

1
2
3
4
5
6
7
8
int getf(int x,int y){
if (y==0) return C[n][x];
int ret=0;
for (int i=y;i<=n;i++){
ret=Inc(ret,1ll*f[i-1][y-1]*a[i]%P*C[n-i][x-y]%P);
}
return ret;
}

getg也是一样

1
2
3
4
5
6
7
8
9
10
11
12
g[0][0]=0;
for (int i=1;i<=n;i++){
for (int j=1;j<=i;j++) g[i][j]=Inc(g[i-1][j],Inc(g[i-1][j-1],1ll*C[i-1][j-1]*b[i]%P));
g[i][0]=g[i-1][0];
}
int getg(int x,int y){
int ret=0;
for (int i=y;i<=n;i++){
ret=Inc(ret,1ll*Inc(g[i-1][y-1],1ll*b[i]*C[i-1][y-1]%P)*C[n-i][x-y]%P);
}
return ret;
}

代码

把上面每一步拼在一起就好了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<cstring>
#include<iostream>
#include<cmath>
#define LL long long
#define N (3005)
using namespace std;
int T,n,m,k,ans;
int g[N][N],f[N][N],C[N][N],a[N],b[N];
const int P=998244353;
template <typename T> void read(T&t) {
t=0;
bool fl=true;
char p=getchar();
while (!isdigit(p)) {
if (p=='-') fl=false;
p=getchar();
}
do {
(t*=10)+=p-48;p=getchar();
}while (isdigit(p));
if (!fl) t=-t;
}
inline bool cmp(int a,int b){
return a>b;
}
int Inc(int a,int b){
return (a+b>=P)?(a+b-P):(a+b);
}
int getf(int x,int y){
if (y==0) return C[n][x];
int ret=0;
for (int i=y;i<=n;i++){
ret=Inc(ret,1ll*f[i-1][y-1]*a[i]%P*C[n-i][x-y]%P);
}
return ret;
}
int getg(int x,int y){
int ret=0;
for (int i=y;i<=n;i++){
ret=Inc(ret,1ll*Inc(g[i-1][y-1],1ll*b[i]*C[i-1][y-1]%P)*C[n-i][x-y]%P);
}
return ret;
}
int main(){
read(T);
C[0][0]=1;
for (int i=1;i<=3000;i++){
for (int j=1;j<=i;j++) C[i][j]=Inc(C[i-1][j-1],C[i-1][j]);
C[i][0]=1;
}
while (T--){
read(n),read(m),read(k);
for (int i=1;i<=n;i++) read(a[i]);
for (int i=1;i<=n;i++) read(b[i]);
sort(a+1,a+n+1,cmp);
sort(b+1,b+n+1,cmp);
f[0][0]=1;
for (int i=1;i<=n;i++){
for (int j=1;j<=i;j++) f[i][j]=Inc(f[i-1][j],1ll*f[i-1][j-1]*a[i]%P);
f[i][0]=f[i-1][0];
}
g[0][0]=0;
for (int i=1;i<=n;i++){
for (int j=1;j<=i;j++) g[i][j]=Inc(g[i-1][j],Inc(g[i-1][j-1],1ll*C[i-1][j-1]*b[i]%P));
g[i][0]=g[i-1][0];
}
ans=0;
for (int i=0;i<=min(m,n);i++){
if (m-i>n) continue;
if (i<k){
ans=Inc(ans,1ll*getf(i,i)*getg(m-i,k-i)%P);
}
else{
ans=Inc(ans,1ll*getf(i,k-1)*getg(m-i,1)%P);
}
}
printf("%d\n",ans);
}
return 0;
}

结语

希望大家都能在WC中取得好成绩