// Answers "sum of LCPs" queries over the suffixes of a single string.
//
// Input (as read below): n and q, the string, then for each query two counts
// k0 and k1 followed by k0 + k1 one-based suffix start positions (the first k0
// form group 0, the rest group 1). Characters are mapped with c - 'a' onto 26
// transitions, so the string is assumed to be lowercase latin letters.
// For every query the program prints the sum, over all pairs (a in group 0,
// b in group 1), of LCP(str[a..], str[b..]).
//
// Approach: a suffix automaton built over the reversed string has a
// suffix-link tree that is the suffix tree of the original string, so the LCP
// of two suffixes is `len` of the LCA of their states. LCA queries use an
// Euler tour plus a sparse table. Each query builds a virtual (auxiliary)
// tree over the queried states and adds, at every node u,
// len(u) * (number of pairs whose LCA is u).
//
// Both tree walks are iterative: the suffix-link tree (and a virtual tree)
// can be a chain as deep as the string is long, which would overflow the
// call stack if walked recursively.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using LL = long long;

// Capacity in automaton states. An automaton has at most 2 * |str| states,
// so this supports strings of roughly 2e5 characters. State 0 is a null
// sentinel; the root is state 1.
constexpr int N = 400010;
// Sparse-table levels; 2^(M - 1) comfortably exceeds the Euler-tour length (< 2N).
constexpr int M = 22;

struct State {
    int link;      // suffix link (parent in the suffix-link tree); 0 for the root
    int len;       // length of the longest string in this state
    int next[26];  // transitions; 0 means "no transition"
};

State tr[N];
int last = 1;  // state reached by the whole string inserted so far
int cnt = 1;   // number of states allocated (root is state 1)

char str[N];
int node_of[N];         // node_of[i]: state of the suffix starting at 1-based position i
std::vector<int> g[N];  // children in the suffix-link tree
std::vector<int> e[N];  // children in the current query's virtual tree
int f[N][2];            // f[u][s]: number of group-s query suffixes counted at / below u
int stk[N];             // stack used while building the virtual tree

int d[N];           // depth in the suffix-link tree
int id[N];          // first position of each state in the Euler tour (1-based)
int tour_len;       // number of entries written to the Euler tour
int lg[N << 1];     // lg[x] = floor(log2(x)) for x >= 1
int st[N << 1][M];  // st[i][k]: shallowest state among tour entries [i, i + 2^k)

// Reads one integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in the skipped prefix makes the value negative. At end of
// input it stops with x == 0 instead of looping forever.
template <class T>
void read(T &x) {
    int c = std::getchar();  // int, not char, so EOF is distinguishable
    bool negative = false;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') negative = true;
        c = std::getchar();
    }
    x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    if (negative) x = -x;
}

// Reads several integers in order.
template <class T, class... Rest>
void read(T &x, Rest &...rest) {
    read(x);
    read(rest...);
}

// Appends character c (expected in 0..25) to the automaton using the standard
// online construction, cloning a state when a transition must be split.
void extend(int c) {
    const int cur = ++cnt;
    tr[cur].len = tr[last].len + 1;
    int p = last;
    last = cur;
    for (; p && !tr[p].next[c]; p = tr[p].link) tr[p].next[c] = cur;
    if (!p) {
        tr[cur].link = 1;
        return;
    }
    const int q = tr[p].next[c];
    if (tr[q].len == tr[p].len + 1) {
        tr[cur].link = q;
        return;
    }
    const int clone = ++cnt;
    tr[clone] = tr[q];
    tr[clone].len = tr[p].len + 1;
    tr[q].link = tr[cur].link = clone;
    for (; p && tr[p].next[c] == q; p = tr[p].link) tr[p].next[c] = clone;
}

// Returns whichever of states a and b is shallower in the suffix-link tree.
int dmin(int a, int b) { return d[a] < d[b] ? a : b; }

// Writes the Euler tour of the suffix-link tree rooted at state 1 into
// st[.][0], filling d[] and id[]. A state is re-emitted after each of its
// children is finished. Uses an explicit stack of (state, next child index).
void euler_tour() {
    std::vector<std::pair<int, std::size_t>> frames;
    frames.reserve(cnt);
    d[1] = 0;
    st[id[1] = ++tour_len][0] = 1;
    frames.emplace_back(1, 0);
    while (!frames.empty()) {
        const int u = frames.back().first;
        const std::size_t i = frames.back().second;
        if (i == g[u].size()) {
            frames.pop_back();
            // Returning to the parent: record it again in the tour.
            if (!frames.empty()) st[++tour_len][0] = frames.back().first;
            continue;
        }
        ++frames.back().second;
        const int v = g[u][i];
        d[v] = d[u] + 1;
        st[id[v] = ++tour_len][0] = v;
        frames.emplace_back(v, 0);
    }
}

// Builds the Euler tour and the sparse table over it for O(1) LCA queries.
void init() {
    euler_tour();
    for (int i = 2; i <= tour_len; ++i) lg[i] = lg[i >> 1] + 1;
    for (int k = 0; k < lg[tour_len]; ++k)
        for (int i = 1; i + (2 << k) - 1 <= tour_len; ++i)
            st[i][k + 1] = dmin(st[i][k], st[i + (1 << k)][k]);
}

// Lowest common ancestor of states a and b in the suffix-link tree.
int lca(int a, int b) {
    int l = id[a], r = id[b];
    if (l > r) std::swap(l, r);
    const int k = lg[r - l + 1];
    return dmin(st[l][k], st[r - (1 << k) + 1][k]);
}

// Builds the virtual tree over `keys` plus the root (state 1) and the LCAs
// needed to connect them, storing child lists in e[]. `keys` must hold
// distinct non-root states sorted by id[]. Clears e[] for every node it
// touches, and zeroes f[] for the root and for inserted LCA nodes; f[] of the
// keys themselves is left as the caller set it.
void build_virtual_tree(const std::vector<int> &keys) {
    int top = 1;
    stk[top] = 1;
    e[1].clear();
    f[1][0] = f[1][1] = 0;
    for (int u : keys) {
        const int t = lca(u, stk[top]);
        if (t != stk[top]) {
            // Close off stacked nodes that lie strictly below t.
            while (id[t] < id[stk[top - 1]]) {
                e[stk[top - 1]].push_back(stk[top]);
                --top;
            }
            if (id[t] > id[stk[top - 1]]) {
                // t is a new branching node: it replaces the stack top.
                e[t].clear();
                e[t].push_back(stk[top]);
                stk[top] = t;
                f[t][0] = f[t][1] = 0;
            } else {
                // t is already the next stack entry.
                e[stk[top - 1]].push_back(stk[top]);
                --top;
            }
        }
        stk[++top] = u;
        e[u].clear();
    }
    for (; top > 1; --top) e[stk[top - 1]].push_back(stk[top]);
}

// Returns the sum over every pair (x from group 0, y from group 1) of len of
// their LCA in the current virtual tree rooted at state 1. Accumulates subtree
// counts into f[] as a side effect. Runs bottom-up without recursion.
LL sum_pair_lcp() {
    std::vector<int> order;
    std::vector<int> todo{1};  // start from the root
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : e[u]) todo.push_back(v);
    }
    LL total = 0;
    // Every node is appended after its parent, so a reverse sweep visits
    // all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        LL pairs_below = 0;  // pairs whose LCA lies strictly inside one child subtree
        for (int v : e[u]) {
            f[u][0] += f[v][0];
            f[u][1] += f[v][1];
            pairs_below += static_cast<LL>(f[v][0]) * f[v][1];
        }
        total += (static_cast<LL>(f[u][0]) * f[u][1] - pairs_below) * tr[u].len;
    }
    return total;
}

int main() {
    int n = 0, q = 0;
    read(n, q);
    if (std::scanf("%s", str) != 1) return 0;

    // Inserting str back to front makes the suffix-link tree the suffix tree
    // of str; the state created for str[i] represents the suffix at i + 1.
    for (int i = n - 1; i >= 0; --i) {
        extend(str[i] - 'a');
        node_of[i + 1] = last;
    }
    for (int v = 2; v <= cnt; ++v) g[tr[v].link].push_back(v);
    init();

    std::vector<int> keys;
    for (; q > 0; --q) {
        int k[2] = {0, 0};
        read(k[0], k[1]);
        keys.clear();
        for (int side = 0; side < 2; ++side) {
            for (int j = 0; j < k[side]; ++j) {
                int pos = 0;
                read(pos);
                const int u = node_of[pos];
                keys.push_back(u);
                f[u][0] = f[u][1] = 0;  // reset before counting; a state may repeat
            }
        }
        for (int j = 0; j < k[0] + k[1]; ++j) ++f[keys[j]][j < k[0] ? 0 : 1];
        std::sort(keys.begin(), keys.end(), [](int a, int b) { return id[a] < id[b]; });
        keys.erase(std::unique(keys.begin(), keys.end()), keys.end());
        build_virtual_tree(keys);
        std::printf("%lld\n", sum_pair_lcp());
    }
    return 0;
}