2

· · 题解

别的部分其它题解讲的已经很详细了,这里给出建虚树一个更好写的做法。

首先记关键点序列为 a,将 a\text{dfn} 升序排序,将 a 中所有点和每个 \text{lca}(a_i, a_{i + 1}) 扔到一个序列 p 中,去重。

然后对于 p,再次按 \text{dfn} 升序排序,并将 \text{lca}(p_i, p_{i + 1})p_{i + 1} 连边,这样就建好了虚树。

::::info[参考实现]

        for (int i = 1; i <= k; i++) cin >> a[i], b[a[i]] = 1;
        sort(a + 1, a + k + 1, [&](const int &u, const int &v) { return dfn[u] < dfn[v]; });
        len = 0;
        for (int i = 1; i < k; i++) {
            p[++len] = a[i];
            p[++len] = lca(a[i], a[i + 1]);
        }
        p[++len] = a[k];
        sort(p + 1, p + len + 1, [&](const int &u, const int &v) { return dfn[u] < dfn[v]; });
        len = unique(p + 1, p + len + 1) - p - 1;
        for (int i = 1; i < len; i++) {
            int l = lca(p[i], p[i + 1]);
            e[l].push_back(p[i + 1]);
        }

:::: 这为什么对呢?考虑 p 中两个相邻元素 x, y,若 xy 的祖先,则由于其 \text{dfn} 相邻,所以 xy 之间没有其它关键点,故 x 需要向 y 连边。否则,同样地,\text{lca}(x, y)y 之间也没有其它关键点,于是需要和 y 连边。

这同时也证明了虚树的点数上界为 2k-1

和单调栈建虚树的对比:

  1. 要排序两次,求至多 2k-1\text{lca},故常数是单调栈建虚树的两倍。
  2. 实现较简单,不容易写挂。

提供一个树剖求 \text{lca} 的实现:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
// typedef __int128 i128;
typedef pair<int, int> pii;
const int N = 2.5e5 + 10, mod = 998244353;
template<typename T>
void dbg(const T &t) { cout << t << endl; }
template<typename Type, typename... Types>
void dbg(const Type& arg, const Types&... args) {
    cout << arg << ' ';
    dbg(args...);
}
namespace Loop1st {
int n, m, a[N], b[N], p[N << 1], sz[N], fa[N], dep[N], son[N], top[N], dfn[N], idx, len;
ll c[N], dp[N];
vector<pii>g[N];
vector<int>e[N];
void dfs1(int u) {
    sz[u] = 1;
    for (auto [v, w] : g[u]) if (v != fa[u]) {
        c[v] = min(c[u], (ll)w);
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int tp) {
    dfn[u] = ++idx;
    top[u] = tp;
    if (son[u]) dfs2(son[u], tp);
    for (auto [v, w] : g[u]) if (v != fa[u] && v != son[u]) dfs2(v, v);
}
int lca(int u, int v) {
    int x = top[u], y = top[v];
    while (x != y) {
        if (dep[x] < dep[y]) swap(x, y), swap(u, v);
        u = fa[x]; x = top[u];
    }
    return dfn[u] < dfn[v] ? u : v;
}
void dfs3(int u) {
    dp[u] = 0;
    ll sum = 0;
    for (int v : e[u]) {
        dfs3(v);
        sum += dp[v];
    }
    if (b[u]) dp[u] = c[u];
    else dp[u] = min(sum, c[u]);
    e[u].clear();
}
void main() {
    cin >> n;
    for (int i = 1, u, v, w; i < n; i++) {
        cin >> u >> v >> w;
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    c[1] = 1ll << 60;
    dfs1(1);
    dfs2(1, 1);
    cin >> m;
    while (m--) {
        int k;
        cin >> k;
        for (int i = 1; i <= k; i++) cin >> a[i], b[a[i]] = 1;
        sort(a + 1, a + k + 1, [&](const int &u, const int &v) { return dfn[u] < dfn[v]; });
        len = 0;
        for (int i = 1; i < k; i++) {
            p[++len] = a[i];
            p[++len] = lca(a[i], a[i + 1]);
        }
        p[++len] = a[k];
        sort(p + 1, p + len + 1, [&](const int &u, const int &v) { return dfn[u] < dfn[v]; });
        len = unique(p + 1, p + len + 1) - p - 1;
        for (int i = 1; i < len; i++) {
            int l = lca(p[i], p[i + 1]);
            e[l].push_back(p[i + 1]);
        }
        dfs3(p[1]);
        cout << dp[p[1]] << '\n';
        for (int i = 1; i <= k; i++) b[a[i]] = 0;
    }
}

}
int main() {
    // freopen("data.in", "r", stdin);
    // freopen("data.out", "w", stdout);
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    int T = 1;
    // cin >> T;
    while (T--) Loop1st::main();
    return 0;
}