题解:P14658 积云四月

· · 题解

定义域空间 S_1\times\cdots\times S_m 和距离函数都是连续的,所以函数值域也是连通的,我们只需要分别求出最小值和最大值即可。

先求最大值。首先证明,对于关于 x 的函数 D(a,x)+D(x,b)x\in S_i),最大值一定可以在 S_i 的叶子处取到。

:::info[证明]

对于一条路径 \operatorname{path}(u,v),将路径上的点 x 按照 D(u,x) 对应到 [0,D(u,v)] 内的实数 t。那么对于一个固定的点 p,不难发现关于 t 的函数 D(p,g(t)) 是一个下凸函数,其中 g(t) 表示实数 t 的对应点。现在考察一个非叶子节点 x,必然存在两个叶子 u,v 使得 x\in\operatorname{path}(u,v),那么显然对于这条路径来说,D(a,g(t))+D(g(t),b) 也是下凸函数,因此两个叶子中至少有一个取值比 x 更优。\Box

:::

有了这个性质,这部分就很简单了。直接 DP,令 f_i(u) 表示考虑 S_{1\sim i}x_i=u 时距离和的最大值。转移就是

f_i(u)=\max_{v\in S_{i-1}}\{f_{i-1}(v)+D(u,v)\}

S_{i-1}\cup S_i 建虚树,给 S_{i-1} 内的点赋点权,跑换根 DP 即可。

再来求最小值。这里只取虚树内的点跑 DP 就不对了,我们需要一些别的做法。

从链入手,每个 S_i 都是一个区间。考虑这样的做法:维护当前答案 ans 和一个区间 I,表示对于 S_{1\sim i}x_i\in I 时距离和可以取到最小值。初始时 ans=0I=S_1。从小到大枚举 i=2\sim n,若 I\cap S_i\neq \varnothing,则令 I\gets I\cap S_i;否则令 ans 加上 IS_i 间的最短距离,把 I 置为里 S_i 最近的那个端点。

自然地推广到树上:猜测 f_i(u) 一定可以被表示成 C_i+D(u,T_i) 的形式,其中 C_i 是常数,T_i\subseteq S_i 是使得距离和最小的 x_i 构成的连通块。初始时 f_1(u)=0+D(u,S_1)。对于 i,有转移

f_i(u)=\min_{v\in S_{i-1}}\{f_{i-1}(v)+D(u,v)\}=C_{i-1}+\min_{v\in S_{i-1}}\{D(v,T_{i-1})+D(u,v)\}

显然

D(v,T_{i-1})+D(u,v)\geq D(u,T_{i-1})

vuT_{i-1} 上的投影点即可取等,因此转移式简化为

f_i(u)=C_{i-1}+D(u,T_{i-1})

T_{i-1}\cap S_i\neq\varnothing,则 C_i=C_{i-1}T_i=T_{i-1}\cap S_i

T_{i-1}\cap S_i=\varnothing,则 C_i=C_{i-1}+D(T_{i-1},S_i)T_iS_i 内离 T_{i-1} 最近的单点。

我们需要支持:连通块求交,连通块求距离最小值和距离最小对应的单点。

不妨先判断是否有交,我们只要能判断某个点 u 是否在某个点集 S 张成的虚树内即可。判断条件是简单的:

T_{i-1} 的叶子集合为 L_TS_i 的叶子集合为 L_i。对 L_T\cup L_i 建出虚树,DP 出每个点内 L_T/L_i 的叶子个数,这样就可以 \mathcal{O}(1) 判断一个点是否在某个虚树内了。

若有交,则深度最小的交点为交集的根,向下沿着公共边即可得到整个交集。为了保证时间复杂度正确,我们只需要保留叶子节点和根来表示整个连通块。

若无交,我们同样对 L_T\cup L_i 建出虚树,以 T_{i-1} 内的点为源点,换根 DP 求出多源最短路(和前面求最大值基本一致),枚举 S_i 内的点求出距离最小值和对应点即可。

这里需要证明,距离 T_{i-1} 最近的点一定是虚树上的点。

:::info[证明]

反证。若存在某条虚树边 (x,y) 内部的点 u 距离 T_{i-1} 最近,那么 u 到其在 T_{i-1} 上的投影点的路径必然不会经过 x,y,那么 u 的度数至少为 3,矛盾。\Box

:::

求点集的根可以按照 DFS 序排序,求第一个点和最后一个点的 LCA。求叶子同样先排序,判断当前点是否是后一个点的祖先即可。

使用 \mathcal{O}(n\log{n})-\mathcal{O}(1) LCA,时间复杂度为 \mathcal{O}(n\log{n}+\sum k_i\log{k_i})。瓶颈在于 RMQ 和排序,理论可以做到线性。

代码写了一坨,仅供参考。

:::success[主要代码]

int n, m, k[MAXN];
int timer, dfn[MAXN], edfn[MAXN], dep[MAXN];
int top, stk[MAXN];
bool vis[MAXN], visA[MAXN], inA[MAXN], inB[MAXN];
int mxSon[MAXN], mnSon[MAXN];
ll dis[MAXN];
i128 mnv, dp[MAXN], ndp[MAXN], mx[MAXN][2];
ll dp2[MAXN], mn[MAXN][2];
int cntA[MAXN], cntB[MAXN];
vector<int> points, S;
int p[MAXN], L[MAXN], stP[MAXN], stL[MAXN];
vector<pii> T[MAXN];
vector<pair<int, ll>> VT[MAXN];

struct ST {
    int f[LOGN][MAXN];
    int get(int x, int y) const { return dfn[x] < dfn[y] ? x : y; }
    void init() {
        for (int i = 1; (1 << i) <= n; ++i)
            for (int j = 1; j <= n - (1 << i) + 1; ++j)
                f[i][j] = get(f[i - 1][j], f[i - 1][j + (1 << i - 1)]);
    }
    int query(int l, int r) {
        int k = lg2(r - l + 1);
        return get(f[k][l], f[k][r - (1 << k) + 1]);
    }
} st;

int lca(int x, int y) {
    if (x == y) return x;
    if (dfn[x] > dfn[y]) swap(x, y);
    return st.query(dfn[x] + 1, dfn[y]);
}

ll dist(int x, int y) {
    return dis[x] + dis[y] - dis[lca(x, y)] * 2;
}

void dfs1(int u, int faU) {
    dfn[u] = ++timer;
    st.f[0][timer] = faU;
    for (auto [v, w] : T[u]) {
        if (v == faU) continue;
        dep[v] = dep[u] + 1;
        dis[v] = dis[u] + w;
        dfs1(v, u);
    }
    edfn[u] = timer;
}

void dfs2(int u) {
    ndp[u] = visA[u] ? dp[u] : -inf1;
    mx[u][0] = mx[u][1] = -inf1;
    mxSon[u] = 0;
    for (auto [v, w] : VT[u]) {
        dfs2(v);
        i128 val = ndp[v] + w;
        if (val >= mx[u][0]) {
            mx[u][1] = mx[u][0];
            mx[u][0] = val;
            mxSon[u] = v;
        } else {
            chkMax(mx[u][1], val);
        }
    }
    chkMax(ndp[u], mx[u][0]);
}

void dfs3(int u, i128 val) {
    chkMax(ndp[u], val);
    for (auto [v, w] : VT[u])
        dfs3(v, max({visA[u] ? dp[u] : -inf1, mx[u][mxSon[u] == v], val}) + w);
}

void dfs4(int u) {
    for (auto [v, w] : VT[u]) {
        dfs4(v);
        cntA[u] += cntA[v];
        cntB[u] += cntB[v];
    }
}

void dfs5(int u) {
    bool leaf = true;
    for (auto [v, w] : VT[u]) {
        if (inA[v] && inB[v]) {
            leaf = false;
            dfs5(v);
        }
    }
    if (leaf && S.back() != u) S.emplace_back(u);
}

void dfs6(int u) {
    dp2[u] = visA[u] ? 0 : inf2;
    mn[u][0] = mn[u][1] = inf2;
    mnSon[u] = 0;
    for (auto [v, w] : VT[u]) {
        dfs6(v);
        ll val = dp2[v] + w;
        if (val <= mn[u][0]) {
            mn[u][1] = mn[u][0];
            mn[u][0] = val;
            mnSon[u] = v;
        } else {
            chkMin(mn[u][1], val);
        }
    }
    chkMin(dp2[u], mn[u][0]);
}

void dfs7(int u, ll val) {
    chkMin(dp2[u], val);
    for (auto [v, w] : VT[u])
        dfs7(v, min({visA[u] ? 0 : inf2, mn[u][mnSon[u] == v], val}) + w);
}

int build(const vector<int> &vec) {
    auto addPoint = [&](int u) {
        if (vis[u]) return;
        vis[u] = true;
        points.emplace_back(u);
    };
    auto addEdge = [&](int u, int v) {
        VT[u].emplace_back(v, dist(u, v));
        addPoint(u);
        addPoint(v);
    };
    stk[top = 1] = vec[0];
    addPoint(vec[0]);
    int sz = vec.size();
    for (int i = 1; i < sz; ++i) {
        int u = vec[i];
        int d = lca(stk[top], u);
        while (top >= 2 && dep[stk[top - 1]] >= dep[d]) {
            addEdge(stk[top - 1], stk[top]);
            --top;
        }
        if (d != stk[top]) {
            addEdge(d, stk[top]);
            stk[top] = d;
        }
        stk[++top] = u;
    }
    for (int i = top - 1; i; --i) addEdge(stk[i], stk[i + 1]);
    return stk[1];
}

ostream &operator<<(ostream &ot, i128 x) {
    static char tmp[40], *pt;
    pt = tmp;
    while (*pt++ = (x % 10) ^ 48, x /= 10);
    while (pt-- != tmp) ot << *pt;
    return ot;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        T[u].emplace_back(v, w);
        T[v].emplace_back(u, w);
    }
    dfs1(1, 0);
    st.init();
    auto cmp = [&](int x, int y) {
        return dfn[x] < dfn[y];
    };
    cin >> m;
    for (int i = 1, curP = 1, curL = 1; i <= m; ++i) {
        cin >> k[i];
        stP[i] = curP;
        stL[i] = curL;
        for (int j = curP; j <= curP + k[i] - 1; ++j) cin >> p[j];
        sort(p + curP, p + curP + k[i], cmp);
        k[i] = unique(p + curP, p + curP + k[i]) - (p + curP);
        for (int j = curP; j < curP + k[i] - 1; ++j) {
            int u = p[j], v = p[j + 1];
            if (dfn[v] > edfn[u]) L[curL++] = u;
        }
        curP += k[i];
        L[curL++] = p[curP - 1];
        if (i == m) {
            stP[m + 1] = curP;
            stL[m + 1] = curL;
        }
    }
    S = vector<int>(p + 1, p + stP[2]);
    for (int i = 2; i <= m; ++i) {
        auto clr = [&]() {
            for (int u : points) {
                vis[u] = false;
                VT[u].clear();
            }
            points.clear();
        };

        vector<int> vec;
        vec.resize(k[i - 1] + k[i]);
        merge(p + stP[i - 1], p + stP[i], p + stP[i], p + stP[i + 1], vec.begin(), cmp);
        vec.erase(unique(vec.begin(), vec.end()), vec.end());
        int rt = build(vec);
        for (int j = stP[i - 1]; j < stP[i]; ++j) visA[p[j]] = true;
        dfs2(rt);
        dfs3(rt, -inf1);
        for (int j = stP[i]; j < stP[i + 1]; ++j) dp[p[j]] = ndp[p[j]];
        for (int j = stP[i - 1]; j < stP[i]; ++j) visA[p[j]] = false;
        clr();
        vec.resize(S.size() + k[i]);
        merge(S.begin(), S.end(), p + stP[i], p + stP[i + 1], vec.begin(), cmp);
        vec.erase(unique(vec.begin(), vec.end()), vec.end());
        rt = build(vec);
        for (int j = 0; j < S.size() - 1; ++j) {
            int u = S[j], v = S[j + 1];
            if (dfn[v] > edfn[u]) ++cntA[u];
        }
        ++cntA[S.back()];
        for (int j = stL[i]; j < stL[i + 1]; ++j) ++cntB[L[j]];
        int rtA = lca(S.front(), S.back()), rtB = lca(p[stP[i]], p[stP[i + 1] - 1]);
        dfs4(rt);
        bool inter = false;
        int rtInter = 0;
        for (int u : points) {
            inA[u] = dfn[u] >= dfn[rtA] && dfn[u] <= edfn[rtA] && cntA[u];
            inB[u] = dfn[u] >= dfn[rtB] && dfn[u] <= edfn[rtB] && cntB[u];
            if (inA[u] && inB[u]) {
                inter = true;
                if (!rtInter || dep[u] < dep[rtInter]) rtInter = u;
            }
        }
        if (inter) {
            S.clear();
            S.emplace_back(rtInter);
            dfs5(rtInter);
        } else {
            for (int u : points)
                if (inA[u])
                    visA[u] = true;
            dfs6(rt);
            dfs7(rt, inf2);
            ll mnd = inf2;
            int mnp = 0;
            for (int u : points) {
                if (inB[u] && dp2[u] < mnd) {
                    mnd = dp2[u];
                    mnp = u;
                }
            }
            mnv += mnd;
            S.clear();
            S.emplace_back(mnp);
            for (int u : points)
                if (inA[u])
                    visA[u] = false;
        }
        for (int u : points) cntA[u] = cntB[u] = 0;
        clr();
    }
    i128 mxv = -inf1;
    for (int i = stP[m]; i < stP[m + 1]; ++i) chkMax(mxv, dp[p[i]]);
    cout << mxv - mnv + 1;
    return 0;
}

:::