1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
   #include <cstdio>
   #include <cctype>
   #include <algorithm>
   #include <cstring>
   #include <iostream>
   #include <cassert>
   #include <cmath>
   using namespace std;
   typedef long long ll;

   /**
    * Reads one decimal integer from stdin into `a`.
    *
    * Every non-digit character before the number is skipped; if any of the
    * skipped characters is '-', the result is negated. Reading stops at the
    * first non-digit after the number (that character is consumed).
    * If end of input is reached before any digit is seen, `a` is set to 0 and
    * the function returns instead of waiting forever.
    *
    * @param a  destination for the parsed value
    * @param c  unused; kept only so existing call sites keep compiling
    * @param f  initial sign multiplier (1 by default)
    */
   template <typename _Tp>
   void read(_Tp &a, char c = 0, int f = 1) {
       (void)c;
       // getchar() returns int: holding it in an int keeps EOF distinguishable
       // from real characters and keeps the isdigit() argument in its valid range.
       int ch = getchar();
       while (ch != EOF && !isdigit(ch)) {
           if (ch == '-') f = -1;
           ch = getchar();
       }
       a = 0;
       while (ch != EOF && isdigit(ch)) {
           a = a * 10 + (ch - '0');
           ch = getchar();
       }
       a *= f;
   }

   /**
    * Writes the decimal representation of `a` to stdout, with a leading '-'
    * for negative values. No trailing separator or newline is emitted.
    *
    * @param a  integer value to print
    */
   template <typename _Tp>
   void write(_Tp a) {
       if (a < 0) {
           putchar('-');
           // Peel the last digit off while the value is still negative, so the
           // most negative value of the type is never negated as a whole.
           if (a <= -10) write(-(a / 10));
           putchar('0' - static_cast<int>(a % 10));
           return;
       }
       if (a > 9) write(a / 10);
       putchar(static_cast<int>(a % 10) + '0');
   }
  // Array capacities: N bounds per-vertex arrays, M bounds per-edge arrays
  // (twice N, since every undirected edge is stored as two directed entries).
  const int N = 100000 + 5;
  const int M = N * 2;
  int n;       // number of vertices (first value read in main)
  int k;       // number of marked vertices (second value read in main)
  int hd[N];   // hd[u]: index of the most recently added edge leaving u, 0 if none
  int nxt[M];  // nxt[e]: next edge in the same vertex's list, 0 terminates the list
  int to[M];   // to[e]: vertex that edge e points to
  int w[M];    // w[e]: weight stored with edge e (written by add)
  int tot;     // index of the last edge slot in use; slot 0 is never used
  int a[N];    // a[1..k]: the marked vertices in input order
  void add(int u, int v, int c) {     nxt[++tot] = hd[u], hd[u] = tot, to[tot] = v, w[tot] = c; }
  bool v1[N];  // v1[x]: x is one of the k marked vertices (set in main)
  bool v2[N];  // v2[x]: set by dfs when a child of x is flagged in v1 or v2
  bool dfs(int u, int fa) {      bool f = 0;     for(int e = hd[u]; e; e = nxt[e]) {         int v = to[e];         if(v == fa) continue;         if(dfs(v, u) || v1[v]) {             f = 1;         }     }     if(f) v2[u] = 1;     return v2[u]; }
  int main() {          read(n), read(k);     for(int i = 1; i < n; i++) {         int u, v, c;         read(u), read(v), read(c);         add(u, v, c);         add(v, u, c);     }     for(int i = 1; i <= k; i++) {         read(a[i]);         v1[a[i]] = 1;     }     for(int i = 1; i <= n; i++) {         if(v1[i]) {             dfs(i, 0);             break;         }     }     int ans = 0;     for(int i = 1; i <= n; i++) {         if(v1[i] || v2[i]) ans++;     }     write(ans), putchar('\n');     return 0; }
 
  |