[LOJ 2541]「PKUWC2018」猎人杀

题目链接

LOJ 2541

做法

以$sum(S)$表示S这个集合中所有人的仇恨值的和。
我们先求$P(S)$表示S中所有的人一定死在第一个人之后的概率,其他人随意。
显然其他人什么时候死我不关心,只考虑下一枪打在S中或者打在1上,所以$P(S)=\frac{a_1}{sum(S)+a_1}$。
考虑最后计算答案,所有人都不能死在1后面。
通过容斥可以得到每个状态S的贡献为,$P(S)*(-1)^{|S|}$
直接计算显然复杂度过大,但是$sum(S)$并不大。
由于容斥系数有-1和1,把方案变成系数和即可。
用$f[i][j]$表示从第二个人开始前i个人,仇恨值为j的系数和。
转移时,考虑第i个人选不选。
如果不选$f[i][j]=f[i-1][j]$
如果选的话,|S|的奇偶性改变,所以系数取负,$f[i][j]-=f[i-1][j-a[i]]$
最后计算答案枚举j即可。
这样可以获得50分的成绩。
发现每次转移等价于乘上一个生成函数$(x^0-x^{a[i]})$,用分治+NTT解决。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<cstring>
#include<iostream>
#include<cmath>
#include<vector>
#define LL long long
#define N (100005)
using namespace std;
int n,val,ans,G=3,invG,invL,lim;
int a[N],R[N<<2];
vector <int> f;
const int P=998244353;
template <typename T> void read(T&t) {
t=0;
bool fl=true;
char p=getchar();
while (!isdigit(p)) {
if (p=='-') fl=false;
p=getchar();
}
do {
(t*=10)+=p-48;p=getchar();
}while (isdigit(p));
if (!fl) t=-t;
}
int Inc(int a,int b){
return (a+b>=P)?(a+b-P):(a+b);
}
int ksm(int a,int b){
int ret=1;
for (;b;b>>=1,a=1ll*a*a%P) if (b&1) ret=1ll*ret*a%P;
return ret;
}
void NTT(vector <int> &f,int lim,int G){
for (int i=0;i<lim;i++){
R[i]=(R[i>>1]>>1)|((i&1)*lim>>1);
if (i>R[i]) swap(f[i],f[R[i]]);
}
for (int i=1;i<lim;i<<=1){
int w0=ksm(G,(P-1)/(i<<1));
for (int j=0;j<lim;j+=(i<<1)){
int w=1;
for (int k=j;k<i+j;k++){
int t=1ll*f[k+i]*w%P;
f[k+i]=Inc(f[k]-t,P);
f[k]=Inc(f[k],t);
w=1ll*w*w0%P;
}
}
}
int invL=ksm(lim,P-2);
if (G!=3){
for (int i=0;i<lim;i++) f[i]=1ll*f[i]*invL%P;
}
}
int solve(int l,int r,vector <int> &f){
if (l==r){
f.resize(a[l]+1);
f[a[l]]=P-1,f[0]=1;
return a[l];
}
int mid=l+r>>1;
vector <int> f0,f1;
int len=solve(l,mid,f0)+solve(mid+1,r,f1);
for (lim=1;lim<=len;lim<<=1);
f0.resize(lim),f1.resize(lim),f.resize(lim);
NTT(f0,lim,G),NTT(f1,lim,G);
for (int i=0;i<lim;i++) f[i]=1ll*f0[i]*f1[i]%P;
NTT(f,lim,invG);
return len;
}
int main(){
read(n);
for (int i=1;i<=n;i++) read(a[i]),val+=a[i];
invG=ksm(G,P-2);
solve(2,n,f);
for (int j=0;j<val;j++){
ans=Inc(ans,1ll*a[1]*ksm(a[1]+j,P-2)%P*f[j]%P);
}
printf("%d",ans);
return 0;
}