Blog of RuSun

\begin {array}{c} \mathfrak {One Problem Is Difficult} \\\\ \mathfrak {Because You Don't Know} \\\\ \mathfrak {Why It Is Difficult} \end {array}

Code Festival 2017 Final J Tree MST

LuoGu: AT3611 Tree MST

AtCoder: J - Tree MST

两点 $a, b$ 的边权 $w _ a + w _ b + dis_ {a, b} = w _ a + w _ b + d _ a + d _ b - 2 d _ {lca} = (w _ a + d _ a) + (w _ b + d _ b) - 2 d _ {lca}$ 。

现在确定了 $lca$ ,如果对于每个子树,都各自内部连通了,那么考虑将这些子树连通。对于每一个点 $a$ ,$w _ a + d _ a$ 是确定的,$-2d _ {lca}$ 也是确定的,我们希望在子树外找到一个点 $b$ 使 $w _ b + d _ b$ 最小。因此只需取出该 $lca$ 下所有点中 $w + d$ 最小的点,将它和其他所有的点连边,这样在原来的 $n ^ 2$ 条边中,只保留了其中有效的边。

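枚举 $lca$ 可以用点分治代替:以当前连通块的重心 $c$ 代替 $lca$ ,令 $d$ 为到 $c$ 的距离,则 $w _ a + w _ b + d _ a + d _ b \ge w _ a + w _ b + dis _ {a, b}$ ,且当 $a, b$ 位于 $c$ 的不同子树(或其中之一就是 $c$)时取等,因此多加入的边权只会偏大,不会使答案变小。每层分治中每个点只连一条边,共 $O(\log n)$ 层,所以选出的边有 $O(n \log n)$ 个。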

最后对这些边做一次最小生成树(Kruskal)即可,算上排序总复杂度为 $O(n \log ^ 2 n)$ 。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
template <class Type>
void read (Type &x)
{
	// Reads the next integer from stdin into x.
	// Characters before the first digit are skipped; if a '-' appears among
	// them the value is negated. The character that ends the number is consumed.
	// If input ends before any digit is seen, x becomes 0 instead of looping
	// forever. c is an int, not a char, so that EOF is distinguishable from
	// every real character.
	int c = getchar();
	bool flag = false;
	while (c != EOF && (c < '0' || c > '9'))
	{
		if (c == '-')
			flag = true;
		c = getchar();
	}
	x = 0;
	while (c != EOF && c >= '0' && c <= '9')
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	if (flag)
		x = -x;
}
template <class Type>
void write (Type x)
{
	// Prints x in decimal to stdout (no separator or newline).
	// Digits are produced least-significant first into a small buffer and then
	// emitted in reverse order.
	if (x < 0)
	{
		putchar('-');
		x = ~x + 1;
	}
	char digits[64];
	int len = 0;
	do
	{
		digits[len++] = char('0' + x % 10);
		x /= 10;
	} while (x > 0);
	while (len > 0)
		putchar(digits[--len]);
}
template <class Type>
void chkmax (Type &x, Type k)
{
	// Raises x to k when k is strictly larger; otherwise x is left unchanged.
	if (k > x)
		x = k;
}
typedef long long LL;
// N: capacity for per-vertex arrays; M: capacity for adjacency-list entries
// (main stores every tree edge twice, once per direction).
const int N = 2e5 + 10, M = 4e5 + 10;
// vis[x]: x has already been chosen as a centroid by calc, so later
// traversals treat it as removed.
bool vis[N];
// n: number of vertices (1-indexed); w[x]: weight of vertex x.
int n, w[N];
// d[x]: w[x] plus the distance from x to the current centroid (written by
// Dist / calc before it is read).
LL d[N];
// Linked-list adjacency: hd[x] is the first entry of x's list (-1 = empty),
// nxt[i] the next entry, edg[i] the target vertex, wt[i] the edge length;
// idx is the last used entry (entries start at 1).
int idx, hd[N], nxt[M], edg[M], wt[M];
namespace DSU
{
int p[N];
int fa (int x)
{
return p[x] == x ? x : p[x] = fa(p[x]);
}
void init()
{
for (int i = 1; i <= n; i++)
p[i] = i;
}
bool check (int a, int b)
{
return fa(a) ^ fa(b);
}
void merge (int a, int b)
{
p[fa(a)] = fa(b);
}
}
// Candidate MST edge between vertices u and v with cost w.
struct Edge
{
	int u, v;
	LL w;
	// Orders edges by ascending cost so that sort() yields Kruskal order.
	bool operator < (const Edge &other) const
	{
		return other.w > w;
	}
};
// Candidate edges collected by calc(); main sorts them and runs Kruskal.
vector <Edge> e;
void add (int a, int b, int c)
{
nxt[++idx] = hd[a];
hd[a] = idx;
edg[idx] = b;
wt[idx] =c;
}
// Counts the vertices reachable from x without entering a vertex marked in
// vis, i.e. the size of x's current component.
// fa is the vertex we came from and is not walked back into.
int Size (int x, int fa)
{
	if (vis[x])
		return 0;
	int cnt = 1;
	for (int i = hd[x]; i != -1; i = nxt[i])
	{
		const int to = edg[i];
		if (to != fa)
			cnt += Size(to, x);
	}
	return cnt;
}
// DFS over x's current component (tot vertices in total).
// Every vertex whose largest remaining piece after its removal has at most
// tot / 2 vertices is written to wc. The last such vertex in post-order wins.
// Returns the number of vertices in the DFS subtree of x (0 if x is in vis).
int WeightCentre (int x, int fa, int tot, int &wc)
{
	if (vis[x])
		return 0;
	int below = 1, heaviest = 0;
	for (int i = hd[x]; i != -1; i = nxt[i])
	{
		const int to = edg[i];
		if (to == fa)
			continue;
		const int part = WeightCentre(to, x, tot, wc);
		below += part;
		if (part > heaviest)
			heaviest = part;
	}
	// The piece "above" x: everything in the component outside x's subtree.
	const int above = tot - below;
	if (above > heaviest)
		heaviest = above;
	if (heaviest <= tot / 2)
		wc = x;
	return below;
}
// DFS from x, skipping vis vertices and the vertex fa we came from.
// s is the accumulated edge length from the traversal's origin to x.
// Sets d[v] = s_v + w[v] for each visited vertex v and appends v to p in
// pre-order.
void Dist (int x, int fa, LL s, vector <int> &p)
{
	if (vis[x])
		return;
	p.push_back(x);
	d[x] = s + w[x];
	for (int i = hd[x]; i != -1; i = nxt[i])
	{
		const int to = edg[i];
		if (to != fa)
			Dist(to, x, s + wt[i], p);
	}
}
// Centroid decomposition step for the component containing x.
// 1. Find its centroid c.
// 2. Set d[v] = w[v] + dist(v, c) for every vertex v of the component.
// 3. Add a star of candidate edges from the vertex t with the smallest d to
//    every other vertex, with cost d[t] + d[v]. By the triangle inequality
//    this is never below the true cost w[t] + w[v] + dist(t, v).
// 4. Recurse into the pieces left after removing c.
void calc (int x)
{
	if (vis[x])
		return;
	WeightCentre(x, -1, Size(x, 0), x); // x now holds the centroid
	vis[x] = true;
	vector <int> p;
	for (int i = hd[x]; ~i; i = nxt[i])
		Dist(edg[i], x, wt[i], p);
	d[x] = w[x];
	p.push_back(x);
	// First vertex with the minimum d (p is never empty: it holds x).
	int t = p[0];
	for (int i : p)
		if (d[i] < d[t])
			t = i;
	// One edge per other vertex. The centroid is already in p, so it needs no
	// separate edge, and the self-loop (t, t) is skipped.
	for (int i : p)
		if (i != t)
			e.push_back(Edge{t, i, d[t] + d[i]});
	for (int i = hd[x]; ~i; i = nxt[i])
		calc(edg[i]);
}
int main ()
{
read(n);
for (int i = 1; i <= n; i++)
hd[i] = -1;
for (int i = 1; i <= n; i++)
read(w[i]);
for (int i = 1, a, b, c; i < n; i++)
{
read(a), read(b), read(c);
add(a, b, c), add(b, a, c);
}
calc(1);
sort(e.begin(), e.end());
LL res = 0;
DSU::init();
for (Edge i : e)
if (DSU::check(i.u, i.v))
{
res += i.w;
DSU::merge(i.u, i.v);
}
write(res);
return 0;
}

另有 Borůvka 算法可以做到 $O(n \log n)$ 。