// (Line-number gutter "1 ... 89" copied from the source web page; not part of the program.)
#include <algorithm>
#include <cstdio>
#include <vector>

using LL = long long;

// Reads the next integer from stdin into x, skipping any non-digit characters.
// A '-' is taken as the sign only when it immediately precedes the first digit.
// Returns false (and sets x = 0) if EOF is reached before any digit. The EOF
// check matters: without it, truncated or empty input makes the skip loop spin
// forever, because getchar() keeps returning EOF, which is never a digit.
template <class Type>
bool read(Type &x) {
    int c = std::getchar();  // int, not char, so EOF stays distinguishable
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9')) {
        negative = (c == '-');
        c = std::getchar();
    }
    x = 0;
    if (c == EOF) return false;
    for (; c >= '0' && c <= '9'; c = std::getchar()) x = x * 10 + (c - '0');
    if (negative) x = -x;
    return true;
}

const int N = 2e5 + 10, mod = 998244353;

// inv[i], fact[i], ifact[i]: i^-1, i!, and (i!)^-1 modulo `mod`,
// all stored as canonical residues in [0, mod).
int inv[N], fact[N], ifact[N];
// A[1..n], B[1..m]: input sequences, coordinate-compressed to 1..L in main.
// cnt[v]: number of not-yet-placed copies of value v from A.
// tr: Fenwick tree over cnt, so query(v) == cnt[1] + ... + cnt[v].
int n, m, L, A[N], B[N], cnt[N << 1], tr[N << 1];

// Fenwick tree point update: position x gains k.
void add(int x, int k) {
    for (; x <= L; x += x & -x) tr[x] += k;
}

// Fenwick tree prefix sum over positions 1..x (0 when x == 0).
int query(int x) {
    int res = 0;
    for (; x > 0; x -= x & -x) res += tr[x];
    return res;
}

// Precomputes inv, fact and ifact for 0..N-1.
// Uses the linear recurrence inv[i] = -(mod / i) * inv[mod % i] (mod p), valid
// because mod is prime. It is written as (mod - mod / i) so every stored value
// is non-negative and no later step depends on signed '%' results.
void init() {
    inv[1] = 1;
    for (int i = 2; i < N; i++)
        inv[i] = static_cast<LL>(mod - mod / i) * inv[mod % i] % mod;
    fact[0] = ifact[0] = 1;
    for (int i = 1; i < N; i++) {
        fact[i] = static_cast<LL>(fact[i - 1]) * i % mod;
        ifact[i] = static_cast<LL>(ifact[i - 1]) * inv[i] % mod;
    }
}

// Reads n, m, then A[1..n] and B[1..m] from stdin. Prints, modulo 998244353,
// the number of distinct rearrangements of the multiset A that are
// lexicographically smaller than B; a proper prefix of B counts as smaller.
//
// Walk B left to right. Let `rest` be the number of elements still unplaced and
// `cur` the number of distinct arrangements of them. Then exactly
// cur * cnt[v] / rest arrangements start with value v. Summing over v < B[i]
// (via the Fenwick tree) counts the strings that first differ from B at
// position i. Then B[i] is fixed and the walk continues.
int main() {
    if (!read(n) || !read(m)) return 0;
    // A, B, fact and inv are fixed-size arrays sized for at most N - 1 elements.
    if (n < 0 || m < 0 || n >= N || m >= N) return 1;
    init();

    // Coordinate-compress every value of A and B into 1..L.
    std::vector<int> vals;
    vals.reserve(static_cast<std::size_t>(n) + m);
    for (int i = 1; i <= n; i++) read(A[i]), vals.push_back(A[i]);
    for (int i = 1; i <= m; i++) read(B[i]), vals.push_back(B[i]);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    L = static_cast<int>(vals.size());
    auto toRank = [&vals](int v) {
        return static_cast<int>(std::lower_bound(vals.begin(), vals.end(), v) - vals.begin()) + 1;
    };
    for (int i = 1; i <= n; i++) A[i] = toRank(A[i]);
    for (int i = 1; i <= m; i++) B[i] = toRank(B[i]);

    // cur = n! / prod(cnt[v]!) = number of distinct arrangements of A.
    for (int i = 1; i <= n; i++) cnt[A[i]]++;
    int cur = fact[n];
    for (int v = 1; v <= L; v++) {
        cur = static_cast<LL>(cur) * ifact[cnt[v]] % mod;
        add(v, cnt[v]);
    }

    int res = 0;
    bool prefixMatched = true;  // A can still reproduce B[1..i-1]
    const int len = std::min(n, m);
    for (int i = 1; i <= len; i++) {
        const int rest = n - i + 1;  // elements of A not yet placed (>= 1)
        cur = static_cast<LL>(cur) * inv[rest] % mod;
        // Strings that match B before position i and put something < B[i] here.
        res = (res + static_cast<LL>(query(B[i] - 1)) * cur) % mod;
        if (cnt[B[i]] == 0) {
            prefixMatched = false;
            break;
        }
        // Fix B[i] at position i: remaining arrangements = cur * cnt[B[i]].
        add(B[i], -1);
        cur = static_cast<LL>(cur) * cnt[B[i]] % mod;
        cnt[B[i]]--;
    }
    // A arranged exactly as B[1..n] is a proper prefix of B, hence smaller.
    if (n < m && prefixMatched) res = (res + 1) % mod;
    std::printf("%d\n", res);
    return 0;
}
// (Trailing web-page residue: a table border and the Gitalk comment-widget
//  placeholder "Gitalk loading ...".)