1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cassert>
#include <cmath>
using namespace std;
typedef long long ll;

// Fast reader for one signed integer from stdin.
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// Parameters:
//   a    - receives the parsed value (0 if the input ends before any digit is found).
//   c, f - scratch state kept as defaulted parameters for compatibility; callers should not pass them.
// `c` is an int rather than a char so EOF from getchar() cannot be confused with a data byte,
// and so every value passed to isdigit() is in its valid domain (unsigned char range or EOF).
template <typename _Tp> void read(_Tp &a, int c = 0, int f = 1) {
    a = 0;
    // Stop at EOF instead of looping forever when the input runs out.
    for (c = getchar(); c != EOF && !isdigit(c); c = getchar())
        if (c == '-') f = -1;
    for (; c != EOF && isdigit(c); c = getchar())
        a = a * 10 + (c - '0');
    a *= f;
}

// Writes the decimal representation of a (with a leading '-' when negative) to stdout via putchar.
// Negative values are printed without evaluating -a, so the most negative value of the type
// (e.g. LLONG_MIN) does not overflow.
template <typename _Tp> void write(_Tp a) {
    if (a < 0) {
        putchar('-');
        // Integer division truncates toward zero, so a % 10 lies in [-9, 0] and a / 10 is safe to negate.
        _Tp rest = a / 10;
        int last = -static_cast<int>(a % 10);
        if (rest != 0) write(-rest);
        putchar('0' + last);
        return;
    }
    if (a > 9) write(a / 10);
    putchar(a % 10 + '0');
}
// Capacity of per-vertex arrays (vertices are 1-indexed) and of per-edge arrays.
// Every undirected edge is stored as two directed entries, hence twice as many edge slots.
const int N = 100000 + 5;
const int M = N * 2;

// n: vertex count, k: number of marked vertices, a[1..k]: the marked vertices as read by main.
// Adjacency lists in linked "forward star" form: hd[u] is the id of the most recently added edge
// leaving u (0 = none), nxt[e] is the next edge with the same tail, to[e] its head, and w[e] the
// value c it was added with. tot counts edges added so far; ids start at 1 so 0 can end a list.
int n, k, tot;
int a[N];
int hd[N], nxt[M], to[M], w[M];

// Prepends the directed edge u -> v (carrying value c) to u's adjacency list.
void add(int u, int v, int c) {
    const int id = ++tot;
    to[id] = v;
    w[id] = c;
    nxt[id] = hd[u];
    hd[u] = id;
}
bool v1[N], v2[N];
bool dfs(int u, int fa) { bool f = 0; for(int e = hd[u]; e; e = nxt[e]) { int v = to[e]; if(v == fa) continue; if(dfs(v, u) || v1[v]) { f = 1; } } if(f) v2[u] = 1; return v2[u]; }
// Reads n and k, then n - 1 edge lines "u v c", then k marked vertices. It prints the number of
// vertices that are marked or have a marked vertex below them when the tree is rooted at the
// lowest-numbered marked vertex. That is the size of the smallest connected subtree containing
// every marked vertex, or 0 when k is 0. Edge values c are stored but do not affect the count.
int main() {
    read(n), read(k);
    // Undirected edges: store both directions.
    for (int edge = 1; edge < n; ++edge) {
        int from, dest, cost;
        read(from);
        read(dest);
        read(cost);
        add(from, dest, cost);
        add(dest, from, cost);
    }
    for (int idx = 1; idx <= k; ++idx) {
        read(a[idx]);
        v1[a[idx]] = 1;
    }
    // Root the traversal at the lowest-numbered marked vertex, if there is one.
    int root = 0;
    for (int x = 1; x <= n && root == 0; ++x)
        if (v1[x]) root = x;
    if (root != 0) dfs(root, 0);

    int ans = 0;
    for (int x = 1; x <= n; ++x)
        ans += (v1[x] || v2[x]) ? 1 : 0;
    write(ans);
    putchar('\n');
    return 0;
}
|