1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
   // Counts the unordered pairs of distinct tree vertices whose connecting path
   // has total edge weight <= m, using centroid decomposition.
   //
   // Input (stdin): n, then n - 1 lines "a b c" (edge a-b with weight c), then m.
   // Output (stdout): the pair count, printed without a trailing newline.
   //
   // NOTE(review): the n - 1 edges are assumed to form a tree; this is not
   // verified, and a cycle would make the recursive walks below never terminate.
   #include <algorithm>
   #include <cstdio>
   #include <utility>
   #include <vector>

   using namespace std;

   // Capacity of the per-vertex arrays (N) and of the directed-edge arrays (M).
   const int N = 4e4 + 10, M = 8e4 + 10;

   bool vis[N]; // vis[v] is set once v has been chosen as a centre in cal()
   int n, m;    // vertex count and the distance limit read from input

   // Forward-star adjacency lists. hd[v] is the newest edge slot leaving v and
   // nxt[] chains to older ones; -1 ends a chain. Slot i leads to edg[i] with
   // weight wt[i]. Slots are numbered from 1 (idx is pre-incremented in add()).
   int idx, hd[N], nxt[M], edg[M], wt[M];

   // Returns the number of unvisited vertices reachable from x without stepping
   // back to fa and without passing through a visited vertex.
   int Size(int x, int fa)
   {
       if (vis[x])
           return 0;
       int res = 1;
       for (int i = hd[x]; i != -1; i = nxt[i])
           if (edg[i] != fa)
               res += Size(edg[i], x);
       return res;
   }

   // Walks the unvisited component containing x (tot vertices in total) and
   // stores in wc every vertex whose removal leaves no piece larger than
   // tot / 2; the last such vertex visited wins. Returns the size of the
   // subtree rooted at x when entered from fa.
   int WeightCentre(int x, int fa, int tot, int &wc)
   {
       if (vis[x])
           return 0;
       int sum = 1, mx = 0;
       for (int i = hd[x]; i != -1; i = nxt[i])
           if (edg[i] != fa)
           {
               int t = WeightCentre(edg[i], x, tot, wc);
               mx = max(mx, t);
               sum += t;
           }
       // The piece on the parent side holds everything outside this subtree.
       mx = max(mx, tot - sum);
       if (mx <= tot / 2)
           wc = x;
       return sum;
   }

   // Appends d for x, and the accumulated path weight for every unvisited
   // vertex reachable from x (not via fa), to q.
   void Dist(int x, int fa, int d, vector<int> &q)
   {
       if (vis[x])
           return;
       q.push_back(d);
       for (int i = hd[x]; i != -1; i = nxt[i])
           if (edg[i] != fa)
               Dist(edg[i], x, d + wt[i], q);
   }

   // Returns the number of index pairs (a, b), a < b, with k[a] + k[b] <= m.
   // k is taken by value because it is sorted here; callers that no longer need
   // their vector can move it in to avoid the copy.
   int get(vector<int> k)
   {
       sort(k.begin(), k.end());
       // Signed count so that an empty vector gives a start index of -1 without
       // relying on unsigned wraparound.
       const int cnt = static_cast<int>(k.size());
       int res = 0;
       // Two pointers: as i moves down, k[i] shrinks, so the largest partner
       // index j that still fits can only grow -- until it is capped below i.
       for (int i = cnt - 1, j = -1; i >= 0; i--)
       {
           // Sum in 64 bits so two large distances cannot overflow int.
           while (j + 1 < i && static_cast<long long>(k[j + 1]) + k[i] <= m)
               j++;
           j = min(j, i - 1);
           res += j + 1;
       }
       return res;
   }

   // Returns the number of qualifying pairs inside the unvisited component that
   // contains x. Picks a centre of that component, counts the pairs whose path
   // passes through the centre, marks it visited, then recurses into the pieces.
   int cal(int x)
   {
       if (vis[x])
           return 0;

       int centre = x;
       WeightCentre(x, -1, Size(x, -1), centre);
       x = centre;
       vis[x] = true;

       int res = 0;
       vector<int> p; // distances from the centre to every other vertex
       for (int i = hd[x]; i != -1; i = nxt[i])
       {
           vector<int> q; // distances from the centre into this one branch
           // fa = -1 is safe: the centre is already marked visited, so the walk
           // cannot step back through it.
           Dist(edg[i], -1, wt[i], q);
           for (int d : q)
           {
               if (d <= m) // pair (centre, vertex)
                   res++;
               p.push_back(d);
           }
           // Pairs inside a single branch do not pass through the centre;
           // cancel them ahead of the get(p) total below.
           res -= get(std::move(q));
       }
       res += get(std::move(p));

       for (int i = hd[x]; i != -1; i = nxt[i])
           res += cal(edg[i]);
       return res;
   }

   // Adds the directed edge a -> b with weight c at the front of a's list.
   void add(int a, int b, int c)
   {
       nxt[++idx] = hd[a];
       hd[a] = idx;
       edg[idx] = b;
       wt[idx] = c;
   }

   int main()
   {
       // Vertices are 1..n and edge slots are 1..2(n-1); both must fit the
       // static arrays. n < 1 would also leave hd[1] at 0 instead of -1.
       if (scanf("%d", &n) != 1 || n < 1 || n >= N || 2 * (n - 1) >= M)
       {
           fprintf(stderr, "invalid vertex count\n");
           return 1;
       }
       for (int i = 1; i <= n; i++)
           hd[i] = -1;

       for (int i = 1; i < n; i++)
       {
           int a = 0, b = 0, c = 0;
           if (scanf("%d%d%d", &a, &b, &c) != 3 || a < 1 || a > n || b < 1 || b > n)
           {
               fprintf(stderr, "invalid edge\n");
               return 1;
           }
           add(a, b, c);
           add(b, a, c);
       }

       if (scanf("%d", &m) != 1)
       {
           fprintf(stderr, "invalid distance limit\n");
           return 1;
       }
       printf("%d", cal(1));
       return 0;
   }
   |