【精進】ABC295 E - Kth Number
公式解説の言い換えが思いつかなかったので遠回り
問題
E - Kth Number
$0$以上$M$以下の整数からなる長さ$N$の数列$A$が与えられる。$A_i = 0$を満たすそれぞれの$i$について$、A_i$を$1$以上$M$以下の一様ランダムに選んだ整数で書き換える。その後$A$を昇順にソートしたとき、$A_K$の期待値を求めよ。
$N,M\le 2000$
解法
数列の添字は0-indexedで考える。特に$K$は入力から$1$デクリメントしている。
$A_i=0$となるインデックスの数を$c$、操作後に$A_K = i$となる場合の数を$B_i$とすると、$\sum_{i=1}^M i B_i/M^c $が求める期待値である。$B_i$を求める。$l,r$をそれぞれ$1\le A_j<i$,$i<A_j\le M$となる$j$の数とする。$c$個の内、$s$個を$i$未満に、$t$個を$i$に、$u$個を$i+1$以上に割り振るのは
通りある。操作後の$A_K$が$i$となるためには$i$未満の個数が$K$個以下、$i+1$以上の要素が$N-K-1$個以下であることが必要十分であるから、$t$を固定したとき、$s,u$に対する条件は$l+s\le K,r+u\le N-K-1,s+u = c-t$である。 従って
\[\begin{align*} B_i &= \sum_{t=0}^c \sum_{\substack{s+u = c-t \\ 0\le s\le K-l \\ 0\le u \le (N-K-1)-r}} \binom{c}{s,t,u}(i-1)^s (M-i)^u \\ &= \sum_{t=0}^c \frac{c!}{t!} \sum_{\substack{s+u = c-t \\ 0\le s\le K-l \\ 0\le u \le (N-K-1)-r}} \frac{(i-1)^s}{s!}\frac{(M-i)^u}{u!} \end{align*}\]と書ける。この計算には$\mathcal{O} (N^2) $時間かかるが、列$f,g$を$f_j ={(i-1)^j}/{j!}$ $(0\le j \le K-l)$,$g_j ={(M-i)^j}/{j!}$ $ (0\le j \le (N-K-1)-r)$により定めると、$f,g$の畳み込み$h$を用いて$B_i = \sum_t (c!/t!) h_{c-t}$となるため、$\mathcal{O}(N\log N)$時間で計算できる。これを各$i$について計算することにより、$\mathcal{O}(MN\log N) $時間で問題を解くことができる。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include "template.hpp"
#include "math/factorial.hpp"
#include <atcoder/convolution>
#include <atcoder/modint>
using mint = atcoder::modint998244353;
factorial<mint> fac;
void solve() {
LL(n, m, k);
--k;
vl a(n);
input(a);
vl cnt(m + 1, 0);
rep(i, n) {
cnt[a[i]]++;
}
mint ans = 0;
ll c = cnt[0], l = 0;
rep(i, 1, m + 1) {
ll r = n - c - l - cnt[i];
ll fsize = k - l + 1, gsize = (n - k - 1) - r + 1;
if(fsize <= 0 or gsize <= 0) {
l += cnt[i];
continue;
}
vector<mint> f(fsize), g(gsize);
rep(j, fsize) {
// 上記の計算量のためにはpowを高速化する必要がある
f[j] = mint(i - 1).pow(j) * fac.inv(j);
}
rep(j, gsize) {
g[j] = mint(m - i).pow(j) * fac.inv(j);
}
auto h = atcoder::convolution(f, g);
rep(j, min(ssize(h), c + 1)) {
ans += fac[c] * fac.inv(c - j) * h[j] * i;
}
l += cnt[i];
}
ans /= mint(m).pow(c);
print(ans.val());
}
int main() {
ios::sync_with_stdio(false);
std::cin.tie(nullptr);
ll t = 1;
rep(_, t) solve();
}