一个新做法

· · 题解

Part0

模拟赛 vp 的时候,写了个 \mathcal{O}(2^n n^3),结果过了,发现题解区做法和我都不太一样,写一发题解。

Part1 题目大意

给你一颗 2^n-1 个节点的根向满二叉树,再给出若干条从祖先到后代的有向边,所有边都有边权,求两两最短路之和,取摸 998244353

Part2 转化

考虑将图反向,容易发现答案不变。

题目转化为给定叶向树,和若干条后代到祖先的有向边,求两两最短路之和。

Part3 思路

考虑对于每个点分别计算答案。

容易得出结论:对于任意一个点 u,它到它子树中任意一个点的最短路,都不会经过第二类边。

反证:如果经过了第二类边,那一定不优。

并且,每一条第二类边只会从后代连向祖先,所以又得到结论:从 u 出发,到达不在 u 子树内的点 v 的路径,一定经过 \mathrm{LCA}(u, v)

于是,可以处理出 u 到达它每个祖先的最短路,就可以算出 u 到达所有点的最短路之和。

考虑怎么处理 u 到所有祖先的最短路。发现树的深度很小,是 \log(\text{点数}) 级别。因此,可以暴力存储 u 到每个祖先的最短距离。

考虑 dfs,从 uu 的儿子 v 递归时,要新加入 v 子树连向 u 祖先的边。可以在 dfs 时记录 u 祖先的全源最短路,加入 v 点子树内的边,相当于在图上新增了若干条边。

具体的,先加入 u \rightarrow v 边权为 a[v] 的边,再对于 v 子树内每条非树边 x \rightarrow y 边权为 z,令 vx 距离为 d,连边 v \rightarrow y,边权为 d + z

而加边操作可以做到 \mathcal{O}(n^2),所以这个算法的复杂度为:

\sum_{i = 1}^{2^n-1}\mathrm{sz}_i\mathrm{dep}_i^2

那么,时间复杂度至少是 \mathcal{O}(n^3 2^n) 的。

Part4 优化

可能需要卡卡常,但是实际上跑不满 n^3,应该会有一个比较客观的常数,能够通过。

Part5 代码

#include <bits/stdc++.h>
using namespace std;

typedef long long lld;

const int maxN = 18, maxV = (1<<maxN)+11;
const lld inf64 = 0x3f3f3f3f3f3f3f3f;
const lld mod = 998244353;
int n, m, a[maxV], V; vector<pair<int,int>> nxt[maxV];
lld ans;
lld dis[maxV]; vector<pair<lld,int>> tof[maxV];
int dep[maxV], sz[maxV];
lld f[maxV];

#define ls (u<<1)
#define rs (ls|1)

struct floyed
{
    lld dis[19][19];
    void push_back(int v, int w)
    {
        for (int i = 1; i < v; ++i)
            dis[i][v] = dis[i][v - 1] + w;
    }
    void add_edge(int v, int t, lld w)
    {
        // fprintf(stderr, "add %d %d %lld\n", v, t, w);
        for (int i = 2; i <= v; ++i)
            for (int j = 1; j < i; ++j)
            {
                dis[i][j] = min(dis[i][j], dis[i][v] + w + dis[t][j]);
            }
    }
    void reset()
    {
        memset(dis, 0x3f, sizeof(dis));
        for (int i = 1; i <= 18; ++i)
            dis[i][i] = 0;
    }
};

void DFS(int u)
{
    sz[u] = 1;
    if (dep[u] < n)
    {
        dis[ls] = dis[u] + a[ls];
        dis[rs] = dis[u] + a[rs];
        dep[ls] = dep[u] + 1;
        dep[rs] = dep[u] + 1;
        DFS(ls);
        DFS(rs);
        sz[u] += sz[ls] + sz[rs];
        f[u] = (f[ls] + lld(sz[ls]) * a[ls] + f[rs] + lld(sz[rs]) * a[rs]) % mod;
        tof[u] = tof[ls];
        for (auto p : tof[rs])
            tof[u].push_back(p);
    }
    for (auto p : nxt[u])
        tof[u].push_back({dis[u] + p.second, p.first});
}

void DFS2(int u, floyed ac)
{
    if (u > 1)
    {
        ac.push_back(dep[u], a[u]);
        for (auto p : tof[u])
        {
            lld w = p.first - dis[u];
            int to = p.second;
            ac.add_edge(dep[u], dep[to], w);
        }
    }
    // fprintf(stderr, "node %d :\n", u);
    // for (int i = 1; i <= dep[u]; ++i)
    //     fprintf(stderr, "\t%lld", ac.dis[dep[u]][i]);
    // fprintf(stderr, "\n");
    int pre = 0;
    int x = u;
    while (x)
    {
        if (ac.dis[dep[u]][dep[x]] >= inf64)
            break;
        int s = sz[x] - sz[pre];
        // if (u == 2 && x == 1)
        //     fprintf(stderr, "*s = %d\n", s);
        (ans += ac.dis[dep[u]][dep[x]] % mod * s + f[x] - f[pre] - lld(sz[pre]) * a[pre]) %= mod;
        pre = x;
        x >>= 1;
    }
    if (dep[u] < n)
    {
        DFS2(ls, ac);
        DFS2(rs, ac);
    }
}

int main()
{
    scanf("%d%d", &n, &m);
    V = (1<<n)-1;
    // fprintf(stderr, "V = %d\n", V);
    for (int i = 2; i <= V; ++i)
        scanf("%d", a + i);
    for (int i = 1; i <= m; ++i)
    {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        nxt[v].push_back({u, w});
    }
    dep[1] = 1;
    DFS(1);
    // fprintf(stderr, "dep[1] = %d\n", dep[1]);
    // for (int i = 1; i <= V; ++i)
    //     fprintf(stderr, "%lld%c", f[i], " \n"[i == V]);
    floyed gen;
    // memset(gen.dis, 0x3f, sizeof(gen.dis));
    gen.reset();
    DFS2(1, gen);
    printf("%lld\n", (ans % mod + mod) % mod);
    fclose(stdin);
    fclose(stdout);
    return 0;
}