#include <cstdio>

// Splits positions 1..n into at most m contiguous groups, minimising the sum of
// the per-group costs g[l][r] built from an n x n matrix of non-negative
// integers, and prints that minimum.
// NOTE(review): looks like Codeforces 321E "Ciel and Gondolas" - confirm.
//
// Method: WQS ("Aliens") binary search on a per-group penalty lambda. For a
// fixed lambda the group-count limit is dropped and the DP
//     f[i] = min_{k < i} f[k] + g[k+1][i] + lambda
// is solved in O(n log n) with a queue of candidate decisions, each owning the
// range of positions where it is currently the best decision.

const int N = 4000 + 10;  // capacity: n + 1 must stay < N (g[n+1][*] is read)

int n, m;
int s[N][N];  // s[i][j] = u[i][1] + ... + u[i][j]  (prefix sums of row i)
int g[N][N];  // g[i][j], i < j: cost of one group i..j = sum of u[a][b] over
              // i <= b <= a <= j excluding u[i][i]; entries with j <= i stay 0.

// Penalised DP value: total cost s (group costs plus lambda per group) and the
// number of groups t used to reach it.
struct Data {
    int t;
    long long s;  // long long: cost + penalties must not overflow
};

// Order by cost, breaking ties towards FEWER groups. With this tie-break the
// smallest lambda whose optimum uses <= m groups yields the exact answer.
bool operator<(const Data &x, const Data &y) {
    return x.s < y.s || (x.s == y.s && x.t < y.t);
}

// Extend a solution by one more group of cost k (lambda is added in check()).
Data operator+(const Data &x, long long k) { return Data{x.t + 1, x.s + k}; }

Data f[N];  // f[i]: best penalised value for prefix 1..i; f[0] stays {0, 0}

// Value for position i when its last group covers k+1..i.
// For i <= k the g term is 0, since those entries are never written.
Data cal(int k, int i) { return f[k] + g[k + 1][i]; }

// Queue entry: decision p is the best known decision for positions l..r.
struct Node {
    int p, l, r;
};
Node q[N];
int hd, tl;  // live queue entries are q[hd..tl]

// Solves the unconstrained DP with penalty x per group, filling f[1..n].
// Afterwards f[n].t is the fewest-groups optimal group count for this x.
void check(long long x) {
    hd = 1;
    tl = 0;
    q[++tl] = Node{0, 1, n};
    for (int i = 1; i <= n; i++) {
        // Drop segments that end before i; the front segment now contains i.
        while (hd <= tl && q[hd].r < i) hd++;
        f[i] = cal(q[hd].p, i);
        f[i].s += x;

        // Insert decision i: pop every tail segment that i beats at its left
        // end, then binary-search where i starts winning in the new tail.
        int p = n + 1;  // first position owned by i (n + 1 means none)
        while (hd <= tl && cal(i, q[tl].l) < cal(q[tl].p, q[tl].l)) p = q[tl--].l;
        if (hd <= tl && cal(i, q[tl].r) < cal(q[tl].p, q[tl].r)) {
            int lo = q[tl].l, hi = q[tl].r;
            while (lo <= hi) {
                int mid = (lo + hi) >> 1;
                if (cal(i, mid) < cal(q[tl].p, mid)) {
                    p = mid;
                    hi = mid - 1;
                } else {
                    lo = mid + 1;
                }
            }
            q[tl].r = p - 1;
        }
        if (p <= n) q[++tl] = Node{i, p, n};
    }
}

// Reads a non-negative decimal integer from stdin, skipping non-digit chars.
// Returns false (leaving x == 0) if EOF is reached before any digit.
// c is an int so that EOF is distinguishable from every character value.
bool read(int &x) {
    x = 0;
    int c = std::getchar();
    while (c != EOF && (c < '0' || c > '9')) c = std::getchar();
    if (c == EOF) return false;
    for (; c >= '0' && c <= '9'; c = std::getchar()) x = x * 10 + (c - '0');
    return true;
}

int main() {
    if (std::scanf("%d%d", &n, &m) != 2) return 0;
    if (n < 0 || n + 1 >= N) return 1;  // would overflow the fixed-size arrays

    // Missing matrix entries (premature EOF) are treated as 0.
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) {
            read(s[i][j]);
            s[i][j] += s[i][j - 1];
        }

    // g[i][j] = g[i][j-1] + (u[j][i] + ... + u[j][j]).
    for (int i = 1; i <= n; i++)
        for (int j = i + 1; j <= n; j++)
            g[i][j] = g[i][j - 1] + s[j][j] - s[j][i - 1];

    // At lambda = g[1][n] any split into t >= 2 groups costs at least
    // 2 * lambda >= lambda + g[1][n], so (with the fewer-groups tie-break) a
    // single group is optimal: hi is always feasible for m >= 1.
    long long lo = 0, hi = g[1][n], res = 0;
    while (lo <= hi) {
        long long mid = lo + (hi - lo) / 2;
        check(mid);
        if (f[n].t <= m) {
            // long long: mid * m overflows int for large penalties.
            res = f[n].s - mid * m;
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    std::printf("%lld", res);
    return 0;
}