Blog of RuSun

\begin {array}{c} \mathfrak {One Problem Is Difficult} \\\\ \mathfrak {Because You Don't Know} \\\\ \mathfrak {Why It Is Diffucult} \end {array}

P2486 [SDOI2011]染色

P2486 [SDOI2011]染色

树上操作,立刻想到了树剖,考虑两段区间如何合并。

颜色段的定义是极长的连续相同颜色被认为是一段。

对于两段区间,只有交界的两个点是连续的,其他的对答案没有影响,所以,额外维护一段区间两段的颜色。合并时,将两段的答案相加,如果交界的颜色是一样的,则答案减一。

现在考虑树上查询时如何计算答案。在查询 $x$ 到 $y$ 的路径时,会一直选择 $x$ 和 $y$ 中深度较大的一个向上跳,而其中的 $x$ 到 $lca$ 和 $y$ 到 $lca$ 的路径( $lca$ 特判)是相对独立的,所以查询线段树上的区间后,维护树上的深度较小的点的颜色,即 $dfn$ 较小的点的颜色,线段树上区间的左端点。下一段树上的链的深度较大的点是和该点交界的,所以如果颜色一样,答案减一。

实际上也直接用线段树的 $pushup$ 操作合并,但是具体操作可能稍微麻烦一点。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1e5 + 10, M = 2e5 + 10;
int n, m, rt, w[N], d[N], p[N];
int idx, hd[N], nxt[M], edg[M];
int stmp, sz[N], son[N], top[N], rnk[N], dfn[N];
struct Node
{
int l, r, tag;
int s, lc, rc;
} tr[N << 2];
void pushup(int x)
{
tr[x].s = tr[x << 1].s + tr[x << 1 | 1].s;
if (tr[x << 1].rc == tr[x << 1 | 1].lc)
tr[x].s--;
tr[x].lc = tr[x << 1].lc;
tr[x].rc = tr[x << 1 | 1].rc;
}
void pushdown(int x)
{
tr[x << 1].s = tr[x << 1 | 1].s = 1;
tr[x << 1].lc = tr[x << 1].rc = tr[x].tag;
tr[x << 1 | 1].lc = tr[x << 1 | 1].rc = tr[x].tag;
tr[x << 1].tag = tr[x << 1 | 1].tag = tr[x].tag;
tr[x].tag = 0;
}
void build(int x, int l, int r)
{
tr[x].l = l;
tr[x].r = r;
if (l == r)
{
tr[x].s = 1;
tr[x].lc = tr[x].rc = w[rnk[l]];
return;
}
int mid = l + r >> 1;
build(x << 1, l, mid);
build(x << 1 | 1, mid + 1, r);
pushup(x);
}
void modify(int x, int l, int r, int k)
{
if (tr[x].l >= l && tr[x].r <= r)
{
tr[x].s = 1;
tr[x].lc = tr[x].rc = k;
tr[x].tag = k;
return;
}
if (tr[x].tag)
pushdown(x);
int mid = tr[x].l + tr[x].r >> 1;
if (l <= mid)
modify(x << 1, l, r, k);
if (r > mid)
modify(x << 1 | 1, l, r, k);
pushup(x);
}
Node query(int x, int l, int r)
{
if (tr[x].l >= l && tr[x].r <= r)
return tr[x];
if (tr[x].tag)
pushdown(x);
int mid = tr[x].l + tr[x].r >> 1;
if (r <= mid)
return query(x << 1, l, r);
else if (l > mid)
return query(x << 1 | 1, l, r);
Node res, lres = query(x << 1, l, r), rres = query(x << 1 | 1, l, r);
res.lc = lres.lc, res.rc = rres.rc, res.s = lres.s + rres.s;
if (lres.rc == rres.lc)
res.s--;
return res;
}
void ModifyPath(int x, int y, int k)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x, y);
modify(1, dfn[top[x]], dfn[x], k);
x = p[top[x]];
}
if (d[x] > d[y])
swap(x, y);
modify(1, dfn[x], dfn[y], k);
}
int QueryPath(int x, int y)
{
int res = 0, xlc = -1, ylc = -1;
while (top[x] != top[y])
{
if (d[top[x]] > d[top[y]])
{
Node q = query(1, dfn[top[x]], dfn[x]);
res += q.s;
if (xlc == q.rc)
res--;
xlc = q.lc;
x = p[top[x]];
}
else
{
Node q = query(1, dfn[top[y]], dfn[y]);
res += q.s;
if (ylc == q.rc)
res--;
ylc = q.lc;
y = p[top[y]];
}
}
if (d[x] > d[y])
{
swap(x, y);
swap(xlc, ylc);
}
Node q = query(1, dfn[x], dfn[y]);
res += q.s;
if (xlc == q.lc)
res--;
if (ylc == q.rc)
res--;
return res;
}
void dfs1(int x)
{
son[x] = -1;
sz[x] = 1;
for (int i = hd[x]; ~i; i = nxt[i])
if (!d[edg[i]])
{
d[edg[i]] = d[x] + 1;
p[edg[i]] = x;
dfs1(edg[i]);
sz[x] += sz[edg[i]];
if (son[x] == -1 || sz[edg[i]] > sz[son[x]])
son[x] = edg[i];
}
}
void dfs2(int x, int t)
{
top[x] = t;
dfn[x] = ++stmp;
rnk[stmp] = x;
if (son[x] == -1)
return;
dfs2(son[x], t);
for (int i = hd[x]; ~i; i = nxt[i])
if (edg[i] != son[x] && edg[i] != p[x])
dfs2(edg[i], edg[i]);
}
void add(int a, int b)
{
nxt[++idx] = hd[a];
hd[a] = idx;
edg[idx] = b;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
hd[i] = -1;
for (int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for (int i = 1, u, v; i < n; i++)
{
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
d[rt = 1] = 1;
dfs1(rt);
dfs2(rt, rt);
build(1, 1, n);
for (char op; m; m--)
{
scanf("%c", &op);
while (op != 'C' && op != 'Q')
scanf("%c", &op);
if (op == 'C')
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
ModifyPath(a, b, c);
}
if (op == 'Q')
{
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", QueryPath(a, b));
}
}
return 0;
}