Blog of RuSun

\begin {array}{c} \mathfrak {One Problem Is Difficult} \\\\ \mathfrak {Because You Don't Know} \\\\ \mathfrak {Why It Is Diffucult} \end {array}

P3723 [AH2017/HNOI2017]礼物

P3723 [AH2017/HNOI2017]礼物

注意到 $m$ 很小,考虑直接枚举改变的亮度,这样,只要确定了顺序,就可以在 $O(nm)$ 的范围内计算,这意味着我们需要很快得确定顺序。

假设顺序已经确定,先推式子,设改变的亮度为 $c$ ,则:
$$
\begin {aligned}
& \sum _ {i = 0} ^ {n - 1} (x _ i - y _ i + c) ^ 2 \\
= & \sum _ {i = 0} ^ {n - 1} ((x _ i - y _ i) ^ 2 + 2(x _ i - y _ i) c + c ^ 2) \\
= & \sum _ {i = 0} ^ {n - 1} (x _ i ^ 2 + y ^ 2) + 2c (\sum _ {i = 0} ^ {n - 1} x _ i - \sum _ {i = 1} ^ {n - 1} y _ i) + c ^ 2 - 2 \sum _ {i = 0} ^ {n - 1} x _ i y _ i
\end {aligned}
$$
可以发现前面的都可以确定了,我们需要快速确定 $\displaystyle \sum _ {i = 0} ^ {n - 1} x _ i y _ i$ 的最大值。

顺序是可以随便转,所以 $\displaystyle \sum _ {i = 0} ^ {n - 1} x _ i y _ i = \sum _ {i = 0} ^ {n - 1} x _ i y _ {n - i - 1} = x * y [n - 1]$ 。考虑将 $x$ 破链成环,做一次卷积,所有的顺序的答案就可以都算出来了。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <cstdio>
#include <algorithm>
using namespace std;
template <class Type>
void read(Type &x)
{
char c;
bool flag = false;
while ((c = getchar()) < '0' || c > '9')
c == '-' && (flag = true);
x = c - '0';
while ((c = getchar()) >= '0' && c <= '9')
x = (x << 3) + (x << 1) + c - '0';
flag && (x = ~x + 1);
}
template <class Type>
void write(Type x)
{
x < 0 && (putchar('-'), x = ~x + 1);
x > 9 && (write(x / 10), 0);
putchar(x % 10 + '0');
}
typedef long long LL;
const int N = 5e4 + 10, mod = 998244353;
int n, m, A[N << 1], B[N], rev[N << 4];
int binpow(int b, int k)
{
int res = 1;
while (k)
{
if (k & 1)
res = (LL)res * b % mod;
b = (LL)b * b % mod;
k >>= 1;
}
return res;
}
void ntt(int *x, int bit, int op)
{
int tot = 1 << bit;
for (int i = 1; i < tot; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
for (int i = 0; i < tot; i++)
i < rev[i] && (swap(x[i], x[rev[i]]), 0);
for (int mid = 1; mid < tot; mid <<= 1)
{
int w1 = binpow(3, (mod - 1) / (mid << 1));
op == -1 && (w1 = binpow(w1, mod - 2));
for (int i = 0; i < tot; i += mid << 1)
for (int j = 0, cur = 1; j < mid; j++, cur = (LL)cur * w1 % mod)
{
int p = x[i + j], q = (LL)x[i + j + mid] * cur % mod;
x[i + j] = (p + q) % mod, x[i + j + mid] = (p - q + mod) % mod;
}
}
if (~op)
return;
int intot = binpow(tot, mod - 2);
for (int i = 0; i < tot; i++)
x[i] = (LL)x[i] * intot % mod;
}
void PolyMul(int *x, int a, int *y, int b, int *g)
{
int bit = 0;
while ((1 << bit) < a + b - 1)
bit++;
ntt(x, bit, 1), ntt(y, bit, 1);
for (int i = 0; i < (1 << bit); i++)
g[i] = (LL)x[i] * y[i] % mod;
ntt(g, bit, -1);
}
int cal(int x)
{
int res = 0;
for (int i = 0; i < n; i++)
res += A[i] * A[i] + B[i] * B[i] + x * x + 2 * x * A[i] - 2 * x * B[i];
return res;
}
void chkmin(int &x, int k)
{
(x > k) && (x = k);
}
void chkmax(int &x, int k)
{
(x < k) && (x = k);
}
int main()
{
read(n), read(m);
for (int i = 0; i < n; i++)
read(A[i]), A[i + n] = A[i];
for (int i = 0; i < n; i++)
read(B[i]);
reverse(B, B + n);
static int a[N << 4], b[N << 4], c[N << 4];
for (int i = 0; i < n + n; i++)
a[i] = A[i];
for (int i = 0; i < n; i++)
b[i] = B[i];
PolyMul(a, n + n, b, n, c);
int res = 2e9;
for (int i = -m; i <= m; i++)
chkmin(res, cal(i));
int mx = 0;
for (int i = n - 1; i < n + n - 1; i++)
chkmax(mx, c[i]);
write(res - mx * 2);
return 0;
}