ARC104 D.Multiset Mean
Problem Statement
元素范围为 1→n,每个元素最多出现 k 次,对于所有可能的集合,求出其平均数,问对于任意 i∈[1,n] 作为平均数的方案数是多少。
Constraints
1≤N,K≤100
10^8≤M≤10^9+9
M is prime.
All values in input are integers.
Solution
考虑一个集合 S 的平均数为 x ,那么先 x 的个数不会对平均数产生影响,故先不考虑。
集合和中的数可以被分为大于 x 和小于 x 两类,不妨设
P={x_i|x_i\in S,x_i<x }
Q={x_i\in S,x_i>x}
显然有 \sum_{x_i \in P} x-x_i=\sum_{x_i \in Q} x_i-x
亦 x-x_i 代替 P 中的 x_i,以 x_i-x 代替 Q 中的 x_i,
P 中的数可能有 1,2,\dots ,x-1 ,Q 中的数可能有 1,2,\dots,n-x
于是我们用 dp[i][j] 表示用前 i 个数,每个数最多用 K 次,和为 j 的结果是多少
如果不考虑每个数最多用 k 次的限制,那么这就是一个完全背包问题,而考虑了这个限制以后,
我们删除 每个数超过 k 次的方案数,即可得到答案。
dp[0][0]=1;
int T = 0;
for(int i=1;i<=n;++i){
T += i*k;
for(int j=0;j<=T;++j){
dp[i][j]=dp[i-1][j];
if(j>=i) Add(dp[i][j],dp[i][j-i]);//完全背包
}
for(int j=T;j>=i*(k+1);--j){
Add(dp[i][j],m-dp[i][j-i*(k+1)]);//删除使用了超过k个i的转移
}
}
为什么只用删除 k+1 个 i 的转移,而不用考虑 k+2,k+3,\dots 呢?因为我们跑的是完全背包,j-i*(k+1) 的转移中已经包含了那些情况了。
At 的官方题解讲了一个同余项合并的优化,考虑对于 dp[i][j] ,暴力的转移就是
$ dp[i][j]=\sum dp[i-1][j-pi]其中1 \leq p\leq k , j-pi\geq 0我们可以发现j-p*i非常像j \mod i于是我们每次只用把当前算的值加到一个统计表示j\mod i的数组s[j\%i] 即可,注意要删去使用了超过k$ 次的转移。非常妙~
At官方代码
for (int i = 1; i <= N; ++i) {
V<Mint> ps(i);
s += i;
for (int j = 0; j <= s * K; ++j) {
int x = j % i;
ps[x] += dp[i - 1][j];
if (j - i * (K + 1) >= 0) ps[x] -= dp[i - 1][j - i * (K + 1)];
dp[i][j] = ps[x];
}
}
然后我们对于 x=i 的答案,P, Q 两个集合和分别为 j 时的方案数相乘, $j:1\to nnk然后考虑i的影响,分别可以放0,1,2,\dots,k个i。故乘以k+1,但是最后答案要-1$ 。
因为要删去空集的情况。
for(int i=1;i<=n;++i){
ll ans = 0 ;
for(int j=0;j<=n*n*k;++j)
Add(ans,dp[i-1][j]*dp[n-i][j]%m);
ans=(ans)*(k+1)%m;
ans--;
printf("%lld\n",ans);
}
Code
#include<bits/stdc++.h>
#define ll long long
const int N = 101;
int n,k;
ll dp[N][N*N*N],s[N],m;
void Add(ll &x,ll y){
x=(x+y)%m;
}
int main(){
scanf("%d%d%lld",&n,&k,&m);
dp[0][0]=1;
int T = 0;
for(int i=1;i<=n;++i){
T += i*k;
for(int j=0;j<=T;++j){
dp[i][j]=dp[i-1][j];
if(j>=i) Add(dp[i][j],dp[i][j-i]);
}
for(int j=T;j>=i*(k+1);--j){
Add(dp[i][j],m-dp[i][j-i*(k+1)]);
}
}
for(int i=1;i<=n;++i){
ll ans = 0 ;
for(int j=0;j<=n*n*k;++j)
Add(ans,dp[i-1][j]*dp[n-i][j]%m);
ans=(ans)*(k+1)%m;
ans--;
printf("%lld\n",ans);
}
return 0;
}
0 条评论