Blog of RuSun

\begin {array}{c} \mathfrak {One Problem Is Difficult} \\\\ \mathfrak {Because You Don't Know} \\\\ \mathfrak {Why It Is Diffucult} \end {array}

P4983 忘情

P4983 忘情

原式化简为 $\displaystyle (\sum _ {i = 1} ^ n x _ i + 1) ^ 2$ 。做前缀和后,区间 $(k, i]$ 的值为 $(s _ i - s _ k + 1) ^ 2$ 。这个式子可以斜率优化。复杂度为 $O(nm)$ ,可以得到 $50pts$ 。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
LL s[N], f[2][N];
int hd, tl, q[N];
int n, m;
LL Y(int x, int a)
{
return f[x & 1][a] + s[a] * s[a] - 2 * s[a];
}
double slp(int x, int a, int b)
{
return (double)(Y(x, a) - Y(x, b)) / (double)(s[a] - s[b]);
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%lld", &s[i]);
s[i] += s[i - 1];
}
for (int i = 1; i <= n; i++)
f[1][i] = (s[i] + 1) * (s[i] + 1);
for (int i = 2; i <= m; i++)
{
hd = 1, tl = 0;
q[++tl] = i - 1;
for (int j = i; j <= n; j++)
{
while (hd < tl && slp(i - 1, q[hd], q[hd + 1]) < s[j] * 2)
hd++;
f[i & 1][j] = f[i - 1 & 1][q[hd]] + (s[j] - s[q[hd]] + 1) * (s[j] - s[q[hd]] + 1);
while (hd < tl && slp(i - 1, q[tl], q[tl - 1]) > slp(i - 1, q[tl], j))
tl--;
q[++tl] = j;
}
}
printf("%lld\n", f[m & 1][n]);
return 0;
}

进一步,可以发现,随着段数的增多,答案越来越小,并且小得越来越慢,于是可以考虑 wqs二分 。每多一段都需要多一个权值。此时不考虑段数,依然可以斜率优化。复杂度为 $O(n \log x)$ 。

细节:此题也会出现斜率不严格单调增加的情况。斜率优化中,一条直线上的若干点都可以成为决策,如果选择了第一个,则选择了最先入队的点, $cnt$ 最小;如果选择了最后一个,则选择了最后入队的点, $cnt$ 最大。对应到 wqs 中,同样的斜率,选择了 $cnt$ 最小的点,那么 $cnt \le m$ 时就需要更新答案;反之,则 $cnt \ge m$ 更新答案。具体地,如果斜率优化中的斜率比较取了等,则为后者。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
LL s[N], f[N];
int hd, tl, q[N];
int n, m, cnt[N];
LL Y(int a)
{
return f[a] + s[a] * s[a] - 2 * s[a];
}
double slp(int a, int b)
{
return (double)(Y(a) - Y(b)) / (double)(s[a] - s[b]);
}
void check(LL x)
{
hd = 1, tl = 0;
q[++tl] = 0;
for (int i = 1; i <= n; i++)
{
while (hd < tl && slp(q[hd], q[hd + 1]) < s[i] * 2)
hd++;
cnt[i] = cnt[q[hd]] + 1;
f[i] = f[q[hd]] + (s[i] - s[q[hd]] + 1) * (s[i] - s[q[hd]] + 1) - x;
while (hd < tl && slp(q[tl], q[tl - 1]) > slp(q[tl], i))
tl--;
q[++tl] = i;
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%lld", &s[i]);
s[i] += s[i - 1];
}
LL l = -1e16, r = 0, res;
while (l < r)
{
LL mid = l + r >> 1;
check(mid);
cnt[n] <= m ? (res = f[n] + m * mid, l = mid + 1) : r = mid - 1;
}
printf("%lld", res);
return 0;
}