一个新做法
Part0
模拟赛 vp 的时候,写了个
Part1 题目大意
给你一颗
Part2 转化
考虑将图反向,容易发现答案不变。
题目转化为给定叶向树,和若干条后代到祖先的有向边,求两两最短路之和。
Part3 思路
考虑对于每个点分别计算答案。
容易得出结论:对于任意一个点
反证:如果经过了第二类边,那一定不优。
并且,每一条第二类边只会从后代连向祖先,所以又得到结论:从
于是,可以处理出
考虑怎么处理
考虑 dfs,从
具体的,先加入
而加边操作可以做到
那么,时间复杂度至少是
Part4 优化
可能需要卡卡常,但是实际上跑不满
Part5 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int maxN = 18, maxV = (1<<maxN)+11;
const lld inf64 = 0x3f3f3f3f3f3f3f3f;
const lld mod = 998244353;
int n, m, a[maxV], V; vector<pair<int,int>> nxt[maxV];
lld ans;
lld dis[maxV]; vector<pair<lld,int>> tof[maxV];
int dep[maxV], sz[maxV];
lld f[maxV];
#define ls (u<<1)
#define rs (ls|1)
struct floyed
{
lld dis[19][19];
void push_back(int v, int w)
{
for (int i = 1; i < v; ++i)
dis[i][v] = dis[i][v - 1] + w;
}
void add_edge(int v, int t, lld w)
{
// fprintf(stderr, "add %d %d %lld\n", v, t, w);
for (int i = 2; i <= v; ++i)
for (int j = 1; j < i; ++j)
{
dis[i][j] = min(dis[i][j], dis[i][v] + w + dis[t][j]);
}
}
void reset()
{
memset(dis, 0x3f, sizeof(dis));
for (int i = 1; i <= 18; ++i)
dis[i][i] = 0;
}
};
void DFS(int u)
{
sz[u] = 1;
if (dep[u] < n)
{
dis[ls] = dis[u] + a[ls];
dis[rs] = dis[u] + a[rs];
dep[ls] = dep[u] + 1;
dep[rs] = dep[u] + 1;
DFS(ls);
DFS(rs);
sz[u] += sz[ls] + sz[rs];
f[u] = (f[ls] + lld(sz[ls]) * a[ls] + f[rs] + lld(sz[rs]) * a[rs]) % mod;
tof[u] = tof[ls];
for (auto p : tof[rs])
tof[u].push_back(p);
}
for (auto p : nxt[u])
tof[u].push_back({dis[u] + p.second, p.first});
}
void DFS2(int u, floyed ac)
{
if (u > 1)
{
ac.push_back(dep[u], a[u]);
for (auto p : tof[u])
{
lld w = p.first - dis[u];
int to = p.second;
ac.add_edge(dep[u], dep[to], w);
}
}
// fprintf(stderr, "node %d :\n", u);
// for (int i = 1; i <= dep[u]; ++i)
// fprintf(stderr, "\t%lld", ac.dis[dep[u]][i]);
// fprintf(stderr, "\n");
int pre = 0;
int x = u;
while (x)
{
if (ac.dis[dep[u]][dep[x]] >= inf64)
break;
int s = sz[x] - sz[pre];
// if (u == 2 && x == 1)
// fprintf(stderr, "*s = %d\n", s);
(ans += ac.dis[dep[u]][dep[x]] % mod * s + f[x] - f[pre] - lld(sz[pre]) * a[pre]) %= mod;
pre = x;
x >>= 1;
}
if (dep[u] < n)
{
DFS2(ls, ac);
DFS2(rs, ac);
}
}
int main()
{
scanf("%d%d", &n, &m);
V = (1<<n)-1;
// fprintf(stderr, "V = %d\n", V);
for (int i = 2; i <= V; ++i)
scanf("%d", a + i);
for (int i = 1; i <= m; ++i)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
nxt[v].push_back({u, w});
}
dep[1] = 1;
DFS(1);
// fprintf(stderr, "dep[1] = %d\n", dep[1]);
// for (int i = 1; i <= V; ++i)
// fprintf(stderr, "%lld%c", f[i], " \n"[i == V]);
floyed gen;
// memset(gen.dis, 0x3f, sizeof(gen.dis));
gen.reset();
DFS2(1, gen);
printf("%lld\n", (ans % mod + mod) % mod);
fclose(stdin);
fclose(stdout);
return 0;
}