LuoGu: CF1278F Cards
F. Cards
同样是将幂转化为下降幂和第二类斯特林数。每次洗牌后王牌位于牌堆顶的概率为 $p = \frac 1 m$，记 $q = 1 - p$。$n$ 次操作相互独立，因此抽到王牌的次数 $x$ 服从二项分布 $B(n, p)$，所求答案即 $E[x^k]$：
$$
\begin {aligned}
ANS & = \sum _ {i = 0} ^ n \binom n i p ^ i q ^ {n - i} i ^ k \\
& = \sum _ {i = 0} ^ n \binom n i p ^ i q ^ {n - i} \sum _ {j = 0} ^ k {k \brace j} i ^ {\underline j} \\
& = \sum _ {j = 0} ^ k {k \brace j} j! \sum _ {i = 0} ^ n \binom n i \binom i j p ^ i q ^ {n - i}\\
& = \sum _ {j = 0} ^ k {k \brace j} j! \sum _ {i = 0} ^ n \binom n j \binom {n - j} {i - j} p ^ i q ^ {n - i}\\
& = \sum _ {j = 0} ^ k {k \brace j} j! \binom n j \sum _ {i = 0} ^ n \binom {n - j} {i - j} p ^ i q ^ {n - i}\\
& = \sum _ {j = 0} ^ k {k \brace j} j! \binom n j \sum _ {i = 0} ^ {n - j} \binom {n - j} i p ^ {i + j} q ^ {n - i - j}\\
& = \sum _ {j = 0} ^ k {k \brace j} n ^ {\underline j} p ^ j\sum _ {i = 0} ^ {n - j} \binom {n - j} i p ^ i q ^ {n - j - i}\\
& = \sum _ {j = 0} ^ k {k \brace j} n ^ {\underline j} p ^ j (p + q) ^ {n - j} \\
& = \sum _ {j = 0} ^ k {k \brace j} n ^ {\underline j} p ^ j
\end{aligned}
$$
查看代码
// CF1278F Cards.
// Input: n (number of shuffles), m (deck size), k (exponent).
// Output: E[x^k] modulo 998244353, where x ~ Binomial(n, 1/m) counts how many
// times the joker is drawn. Computed via
//     E[x^k] = sum_{j=0}^{min(n,k)} S(k, j) * n^{\underline j} * m^{-j},
// where S(k, j) are Stirling numbers of the second kind.
#include <algorithm>
#include <cstdio>
#include <vector>

using LL = long long;
const int mod = 998244353;

// Reads one integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in the skipped prefix makes the value negative.
// `c` is an int (not char) so that EOF is recognised and a truncated input
// cannot make the skip loop spin forever.
template <class Type>
void read(Type &x) {
    int c = getchar();
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') negative = true;
        c = getchar();
    }
    x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (negative) x = -x;
}

// Reads several integers from stdin, left to right.
template <class Type, class... Rest>
void read(Type &x, Rest &...rest) {
    read(x);
    read(rest...);
}

// Writes an integer to stdout in decimal (no trailing newline).
template <class Type>
void write(Type x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9) write(x / 10);
    putchar(static_cast<int>('0' + x % 10));
}

// Returns b^e modulo `mod` by binary exponentiation. With the default
// exponent mod - 2 this is the modular inverse of b (Fermat's little
// theorem; mod is prime), valid when b is not a multiple of mod.
int binpow(int b, int e = mod - 2) {
    int res = 1;
    for (; e; e >>= 1, b = static_cast<LL>(b) * b % mod)
        if (e & 1) res = static_cast<LL>(res) * b % mod;
    return res;
}

// Returns row k of the Stirling numbers of the second kind modulo `mod`:
// result[j] = S(k, j) for 0 <= j <= k, using
//     S(i, j) = S(i - 1, j - 1) + j * S(i - 1, j).
// A single rolling row (O(k) memory) replaces the full (k+1) x (k+1) table.
// j runs downward so row[j - 1] still holds S(i - 1, j - 1) when row[j] is
// overwritten.
std::vector<int> stirling2_row(int k) {
    std::vector<int> row(k + 1, 0);
    row[0] = 1;  // S(0, 0) = 1
    for (int i = 1; i <= k; ++i) {
        for (int j = i; j >= 1; --j)
            row[j] = (row[j - 1] + static_cast<LL>(j) * row[j]) % mod;
        row[0] = 0;  // S(i, 0) = 0 for i >= 1
    }
    return row;
}

// Returns E[x^k] modulo `mod` for x ~ Binomial(n, 1/m).
// Assumes 0 <= n < mod and 1 <= m < mod (so 1/m exists), as in the problem.
// Terms with j > n are skipped because the falling factorial n^{\underline j}
// is 0 there.
int expected_power(int n, int m, int k) {
    const int p = binpow(m);  // p = 1/m (mod `mod`)
    const std::vector<int> s = stirling2_row(k);
    const int last = std::min(n, k);
    int res = 0;
    // Invariant: t = n^{\underline j} * p^j (mod `mod`).
    for (int j = 0, t = 1; j <= last; ++j) {
        res = (res + static_cast<LL>(t) * s[j]) % mod;
        t = static_cast<LL>(t) * (n - j) % mod * p % mod;
    }
    return res;
}

int main() {
    int n = 0, m = 0, k = 0;  // shuffles, deck size, exponent (input order)
    read(n, m, k);
    write(expected_power(n, m, k));
    putchar('\n');
    return 0;
}