NOI2023 D2T1

· · 题解

树上没有横叉边,因此任意一条 u \rightsquigarrow v 的路径必经 \operatorname{LCA}(u, v).因此,d(u, v) = d(u, \operatorname{LCA}(u, v)) + d(\operatorname{LCA}(u, v), v)

前半部分只能是 u 上行到 \operatorname{LCA}(u, v),因此前半部分的答案是平凡的带权深度差.而后半部分 d(\operatorname{LCA}(u, v), v) 则是我们要处理的对象.观察这个形式,整个题目被我们简化为计算所有 祖先通向后代 的最短距离.至于 \operatorname{LCA}(u, v) = v 则是平凡的 corner case.

为什么这是一种简化?因为整个图是满二叉树,所以层数 n 很小.因此,这样的祖先-后代对数量级是 \Theta(n 2^n) 的,不可做的点对数量变得可做.

下面设 u 通向 v 的儿子是 o.从 u 下行到 v 的最短路,只有 u 的祖先和 T(o) 中的点可能被涉及.原因是假如我们走到了一个点 x,它既不是 u 的祖先,也不在 T(o) 中,则 u \rightsquigarrow xx \rightsquigarrow v 两条路径一定是有除了 x 以外的交点的.这就走出了一个环,违背正权图中最短路无环的特征.

因此,计算 d(u, v), v \in T(u) 时,只需考虑 u 的祖先和 T(o) 中的点.而这些点与 u 又构成祖先-后代关系,因此全局地考虑,对于所有 d(u, v \in T(u)) 的计算,总共涉及到的点的数量与祖先-后代点对的数量级同阶.这也就意味着我们可以直接对每个 u,将需要的节点直接导出,做 dijkstra 求出所有的 d(u, v), v \in T(u).总点数量级 \Theta(n 2^n),因此复杂度 \Theta(n2^n \times \log 2^n) = \Theta(n^2 2^n)

为了能够精确扫描到需要的边和节点,计算 d(u, v \in T(u)) 时,可以考虑每次仅加入终点在 T(u) 中的 第一类边和第二类边 + u 的祖先上所有的 第一类边.这里其实又暗含了一个性质是 u \rightsquigarrow v 的过程可能会走到 u 的祖先,但一定会走一次第二类边直接到 T(o) 中(而不会走到 u 的祖先上).原因也是最短路上无正环.

现在理论上所有的 d(u, v \in T(u)) 已可求得,整个题目所需要的全部思维量已经结束,接下来只是一些统计上的完善简化工作,这里略去,具体可以看代码.

#include <bits/stdc++.h>
#define int long long
inline int read() {
    int x = 0; bool f = true; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = false;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
    return f ? x : (~(x - 1));
}
typedef std :: pair <int, int> pii;

const int N = 1 << 20;
const int mod = 998244353;
int a[N], f[N], dis[N], n, m;
std :: vector <pii> G[N], R[N];
std :: vector <int> T[N];

int solve(int s, int o) {
    std :: priority_queue <pii> q;
    for (int u = (s << 1 | o); u; u >>= 1) {
        dis[u] = (int)1e17;
        G[u].emplace_back(u >> 1, a[u]);
    }
    for (int u : T[s << 1 | o]) {
        dis[u] = (int)1e17;
        for (pii e : R[u]) G[e.first].emplace_back(u, e.second);
    }
    dis[s] = 0; q.emplace(0, s);
    while (!q.empty()) {
        int d = q.top().first, u = q.top().second;
        q.pop();
        if (d + dis[u]) continue;
        for (pii e: G[u]) {
            int v = e.first, w = e.second;
            if (dis[v] > dis[u] + w) {
                dis[v] = dis[u] + w;
                q.emplace(-dis[v], v);
            }
        }
    }
    int ans = 0;
    for (int u : T[s << 1 | o]) if (dis[u] < (int)1e17)
        (ans += (dis[u] + a[s << 1 | (o ^ 1)]) * (int)T[s << 1].size() % mod + f[s << 1 | (o ^ 1)] + dis[u]) %= mod;
    for (int u : T[s << 1 | o]) std :: vector <pii> ().swap(G[u]);
    for (int u = s; u; u >>= 1) std :: vector <pii> ().swap(G[u]);
    return ans;
}

signed main() {
    n = (1 << read()) - 1; m = read();
    for (int u = 2; u <= n; ++u)
        R[u >> 1].emplace_back(u, a[u] = read());
    for (int i = 1; i <= m; ++i) {
        int u = read(), v = read(), w = read();
        R[v].emplace_back(u, w);
    }
    for (int u = 1; u <= n; ++u)
        for (int v = u; v; v >>= 1)
            T[v].push_back(u);
    for (int u = n >> 1; u; --u)
        f[u] = (f[u << 1] + f[u << 1 | 1] + (a[u << 1] + a[u << 1 | 1]) * (int)T[u << 1].size()) % mod;
    int ans = 0;
    for (int u = 1; u <= (n >> 1); ++u)
        (ans += solve(u, 0) + solve(u, 1) + f[u]) %= mod;
    printf("%lld\n", ans);
    return 0;
}