@[xingyuliu](/user/590925)
```cpp
#include <bits/stdc++.h>
using namespace std;
// NOTE: this macro silently widens every `int` in the file (globals,
// parameters and locals) to `long long`. That is why `main` cannot be
// spelled `int main`, and why scanf on these variables must use %lld.
#define int long long
// Array capacity. Presumably an upper bound on both the node count n and the
// army count m (a[], cu[] and need[] are all sized by it); TODO confirm
// against the problem limits.
const int N = 6e4 + 3;
// Adjacency list filled by add():
//   head[x]: id of the most recently added edge leaving x (0 = none).
//   nxt[e]:  next edge id in the same list (0 ends the list).
//   ver[e]:  target node of edge e.
//   tot:     number of edge slots used so far.
// Each undirected edge occupies two slots (main calls add in both directions).
int ver[2 * N], nxt[2 * N], head[N * 2], tot;
// Time already spent by the army that chk() is currently lifting.
long long tim;
// n: node count (main reads n-1 edges).
// m: number of armies; a[1..m]: their starting nodes.
int n, m, a[N];
// Binary-lifting table built by dfs1 with node 1 as root:
// f[x][j] is the 2^j-th ancestor of x (0 once the jump passes the root).
int f[N][20];
// dis[x][j]: total edge weight along the jump from x to f[x][j].
long long dis[N][20];
// at: node the current army has reached inside chk().
// cntc: number of filled entries in cu[] (armies that reached the root).
// cntn: number of filled entries in need[] (root children still to be covered).
int at, cntc, cntn;
// ha[x] != 0: an army is stationed at x, so dfs2 treats x's subtree as covered.
bool ha[N];
// edg[e]: weight of edge slot e.
long long edg[2 * N];
// Insert the directed edge x -> y with weight z into x's adjacency list.
// New edges are pushed to the front, so head[x] always names the newest one.
void add(int x, int y, long long z){
    const int id = ++tot;
    ver[id] = y;
    edg[id] = z;
    nxt[id] = head[x];
    head[x] = id;
}
void dfs1(int x, int fa){
for(int i = head[x]; i; i = nxt[i]){
int y = ver[i], z = edg[i];
if(y == fa)
continue;
f[y][0] = x;
dis[y][0] = z;
for(int j = 1; j < 20; j++){
f[y][j] = f[f[y][j-1]][j-1];
dis[y][j] = dis[y][j-1] + dis[f[y][j-1]][j-1];
}
dfs1(y, x);
}
}
// One record in the matching done by chk(). Used for two arrays:
//   cu[]:   an army that reached the root.
//           x = root child it came through; y = time it still has left at the root.
//   need[]: a root child whose subtree is not yet covered.
//           x = that child; y = weight of the edge root -> x.
// us: the record has already been matched.
// Field order matters: chk() builds records with (node){x, y, us}.
struct node{
int x, y;
bool us;
}cu[N], need[N];
// Sort by x ascending, then y ascending. This groups records by root child,
// with the smallest y first inside each group.
bool cmp1(node xx, node yy){
    return xx.x == yy.x ? xx.y < yy.y : xx.x < yy.x;
}
// Records with us == false come first. Within each group, larger y comes first.
bool cmp2(node xx, node yy){
    if(xx.us == yy.us)
        return xx.y > yy.y;
    // Flags differ: xx precedes yy exactly when xx is the unused one.
    return !xx.us;
}
// Returns true iff the subtree rooted at x (entered from parent fa) still
// contains a leaf reachable from x without passing a node marked in ha[].
// In other words, the subtree is not yet sealed off by stationed armies.
// Only called on non-root nodes, so "exactly one incident edge" means leaf.
bool dfs2(int x, int fa){
    if(ha[x])
        return 0;                 // an army at x blocks everything below it
    if(nxt[head[x]] == 0)
        return 1;                 // x's only edge goes to fa: an unguarded leaf
    for(int i = head[x]; i; i = nxt[i]){
        int y = ver[i];
        // One open leaf is enough, so stop without visiting the other siblings.
        if(y != fa && dfs2(y, x))
            return 1;
    }
    return 0;
}
// Feasibility test for the binary search in main. Returns true iff, when
// every army may travel for at most `mid` time units, the armies can be
// placed so that every leaf is cut off from the root (node 1) by a node
// holding an army. Armies are never left standing on the root itself.
//
// Phases:
//   1. Lift each army as high as it can go without entering the root, using
//      f/dis. If it can then step into the root with time to spare, it
//      becomes a "root army" in cu[] (x = root child it came through,
//      y = leftover time). Otherwise it is stationed where it stopped (ha[]).
//   2. Collect the root children whose subtree still has an unguarded leaf
//      (need[], y = weight of the root edge).
//   3. Match root armies to those children greedily.
bool chk(long long mid){
    memset(ha, 0, sizeof(ha));
    cntc = cntn = 0;

    // Phase 1.
    for(int i = 1; i <= m; i++){
        tim = 0;
        at = a[i];
        if(at == 1){
            // Already at the root: the army has its full budget left and did
            // not come from any child (x = 0 never matches a real child).
            cu[++cntc] = (node){0, mid, 0};
            continue;
        }
        // f[at][j] > 1 keeps the army strictly below the root: an ancestor
        // of 1 is the root itself, and 0 means the jump overshoots it.
        for(int j = 19; j >= 0; j--)
            if(tim + dis[at][j] <= mid && f[at][j] > 1){
                tim += dis[at][j];
                at = f[at][j];
            }
        if(f[at][0] == 1 && dis[at][0] + tim < mid)
            cu[++cntc] = (node){at, mid - tim - dis[at][0], 0};
        else
            ha[at] = 1;
    }

    // Phase 2.
    for(int i = head[1]; i; i = nxt[i]){
        int y = ver[i];
        if(dfs2(y, 1))
            need[++cntn] = (node){y, dis[y][0], 0};
    }
    // Each uncovered child consumes one distinct root army.
    if(cntc < cntn)
        return 0;

    // Phase 3a: pick out armies that should go back to cover their own child.
    // cu is grouped by child with the least leftover time first. Take the
    // group's least-time army. If its leftover is smaller than the edge back
    // down to its own child, it cannot cover that child via the root, and it
    // is best used to cover that child directly: any army able to cover the
    // child via the root has more time than it, so that army can take over
    // whatever this one would have covered. An army whose leftover is at
    // least the edge weight stays in the general pool instead.
    sort(need + 1, need + cntn + 1, cmp1);
    sort(cu + 1, cu + cntc + 1, cmp1);
    int now = 1;
    for(int i = 1; i <= cntn; i++){
        // Must stay within cntc: entries past it are stale data from
        // earlier calls and would produce bogus matches.
        while(now <= cntc && cu[now].x < need[i].x)
            now++;
        if(now > cntc)
            break;
        if(cu[now].x == need[i].x && cu[now].y < need[i].y){
            need[i].us = 1;
            cu[now].us = 1;
            now++;
        }
    }

    // Phase 3b: every remaining child must be reached through the root.
    // cmp2 puts unmatched/unused records first, largest y first. Pairing the
    // k-th largest remaining edge with the k-th largest unused leftover time
    // is then optimal.
    sort(need + 1, need + cntn + 1, cmp2);
    sort(cu + 1, cu + cntc + 1, cmp2);
    for(int i = 1; i <= cntn && need[i].us == 0; i++){
        if(cu[i].us == 1)          // ran out of unused root armies
            return 0;
        if(cu[i].y < need[i].y)    // the best remaining army cannot reach it
            return 0;
    }
    return 1;
}
// Entry point. Reads the tree (n nodes, n-1 weighted edges) and the m army
// positions, then binary-searches the smallest time limit accepted by chk().
// Prints that limit, or -1 if even the total edge weight is not enough.
//
// Spelled `signed main` for two reasons:
//   - `#define int long long` would turn `int main` into `long long main`.
//   - A return-type-less `main()` relies on implicit int, which ISO C++ forbids.
signed main(){
    long long l = 0, r = 0;
    scanf("%lld", &n);   // n is long long via the macro; "%d" here was UB
    for(int i = 1; i < n; i++){
        int u, v;
        long long w;
        scanf("%lld %lld %lld", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
        r += w;
    }
    dfs1(1, 0);
    scanf("%lld", &m);
    for(int i = 1; i <= m; i++)
        scanf("%lld", &a[i]);
    // Upper bound: no army needs more than the total edge weight. Its route
    // up to the root and down into another root subtree uses each edge at
    // most once. r = total + 1 is never tested and acts as the "impossible"
    // sentinel.
    r++;
    const long long impossible = r;
    while(l < r){
        long long mid = (l + r) >> 1;
        if(chk(mid))
            r = mid;
        else
            l = mid + 1;
    }
    if(r == impossible)
        puts("-1");
    else
        printf("%lld\n", l);
    return 0;
}
```
by ninji @ 2023-08-16 22:05:39
开O2,80分
by ninji @ 2023-08-16 22:06:07
@[xingyuliu](/user/590925) 剩下的你可以看评论区
by ninji @ 2023-08-16 22:06:25
%%%
by _x_y_ @ 2023-08-17 07:53:48