[BZOJ 2839] 集合计数

链接

BZOJ 2839 集合计数

题目大意

一个有$n$个元素的集合有$2^n$个不同子集(包含空集),现在要在这$2^n$个集合中取出若干集合(至少一个),使得它们的交集的元素个数为K,求取法的方案数,对1e9+7取模

做法

直接求比较困难,我们考虑先求出$a_i$表示交集大小至少为i的方案数。
那么这i个数的取法是$\binom{n}{i}$,剩下的$n-i$个元素取或不取有$2^{n-i}$中选择,而这么多集合分别又有取或不取两种选择,所以$a_i=\binom{n}{i}*(2^{2^{n-i} }-1)$减1是因为至少取一个,把全部不取的情况减掉。
知道这个之后我们考虑构造容斥系数$f_i$,使得$ans=\sum\limits_{i=0}^{n}f_i*a_i$
那么怎么求这个$f$呢,我们考虑每种交集恰好为$x$的选法,对于每个$a_i$的贡献的总和。
由于选取的集合已经固定,所以对于每个$a[i]$的贡献就是$\binom{x}{i}$,总贡献为
$\sum\limits_{i=0}^{x}f_i\binom{x}{i}$
定义kronecker delta函数$g(x,k)$,即
$g(x)=[x=k]$
我们对于交集恰好为k的函数应该对最后答案有1的贡献。
所以$g(x)=\sum\limits_{i=0}^{x}f_i\binom{x}{i}$
那么我们可以二项式反演
二项式反演证明戳这里
$f_{n}=\sum\limits_{i=0}^{n}(-1)^{n-i}\binom{n}{i}g(i)$
由于只有$g(k)=1$,所以
$f_i=(-1)^{i-k}\binom{i}{k}*[i>=k]$
再将$f$带入到最上面求$ans$的式子里就行了。
$ans=\sum\limits_{i=k}^{n}(-1)^{i-k}\binom{i}{k}\binom{n}{i}*(2^{2^{n-i} }-1)$
代码实现的时候$(2^{2^{n-i} }-1)$不好处理,
我们可以$i$从倒序枚举,$2^{2^{n}}=(2^{2^{n-1} })^2$计算

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<cstring>
#include<iostream>
#include<cmath>
#define LL long long
#define N (1000005)
using namespace std;
int n,k,d,base,ans;
int jc[N],A[N],inv[N],f[N],a[N];
const int P=1000000007;
template <typename T> void read(T&t) {
t=0;
bool fl=true;
char p=getchar();
while (!isdigit(p)) {
if (p=='-') fl=false;
p=getchar();
}
do {
(t*=10)+=p-48;p=getchar();
}while (isdigit(p));
if (!fl) t=-t;
}
inline int ksm(int a,int b){
int ret=1;
for (;b;b>>=1,a=1ll*a*a%P) if (b&1) ret=1ll*ret*a%P;
return ret;
}
inline int Inc(int a,int b){
return (a+b>=P)?(a+b-P):(a+b);
}
inline int C(int a,int b){
return 1ll*jc[a]*inv[a-b]%P*inv[b]%P;
}
int main(){
read(n),read(k);
A[0]=A[1]=1;
for (int i=2;i<=n;i++) A[i]=Inc(1ll*-P/i*A[P%i]%P,P);
jc[0]=inv[0]=1;
for (int i=1;i<=n;i++){
jc[i]=1ll*jc[i-1]*i%P;
inv[i]=1ll*inv[i-1]*A[i]%P;
}
base=2;
if ((n-k)&1) d=-1;
else d=1;
for (int i=n;i>=k;i--){
ans=(1ll*d*C(i,k)*C(n,i)%P*(base-1)%P+ans+P)%P;
base=1ll*base*base%P;
d=-d;
}
printf("%d",ans);
return 0;
}