一种简单的 O(nlogn) 预处理 O(1) 查询静态树上链半群的方法

· · 个人记录

前置知识。不看也无所谓,会提。

睡觉的时候脑子里蹦出来的,记录一下。

我们首先考虑链上 O(n\log n) 预处理 O(1) 查询怎么做。这个问题最简单的做法是猫树。把猫树直接往树上搬可以得到一个点分树的算法,但那个东西比较麻烦。

我们考虑使用朴素树状数组解决区间查询问题。我们不难发现我们会从 r 开始,不断跳 r&~-r(人话就是去掉最后一位二进制)。直到到某个位置,再跳就要低于 l 了,于是就卡在这了。我们不难发现我们可以通过简单位运算求出 r 最后卡在了哪里。

那么在卡住之前,对于相同的 r,它的行动轨迹是固定的,至多只有 O(\log n) 种(区分依据是跳了多少步),这个东西显然可以预处理;在卡住之后,我们会发现 l 一定在以 r(卡住时)为右端点的长度为 r&-r 的区间内(不然显然可以继续跳)。然后对于每个点这个区间的长度总和显然是 O(n\log n) 的,拆个贡献就能证。然后我们预处理就做完了。

这个做法对比猫树的优势在于,它从头到尾都是从右往左跳而没有回头,于是这是一个天然适合上树的结构。我们考虑直接把这个东西搬到树上,会发现在 r(方便起见就不换字母了,实际上树上的节点和深度是要区分的)卡住之前的部分一模一样,可以直接抄过来。但 r 卡住之后就遇到了一个问题:每个点深度的二进制最低位权,这东西的和不再是 O(n\log n) 的了。只要能解决这个问题,我们的算法就能上树。

然后如果读者看了那个前置知识就可以把这一段跳了。简单来说,不难注意到我们的深度其实没有意义,真正有意义的是深度差(或者叫“相对深度”)。这也就是说我们可以把每个点的深度同加一个数。我们逐二进制决定这个加的深度。对于从低到高第 i 个二进制位,假设尚未决定二进制最低位权且在这一位上是 0 的节点有 x 个,是 1 的有 y 个,那我们通过抉择这一位加 1 还是加 0 可以让 \max\{x,y\} 个点止步于此,最低位权定在 2^i。对于剩下 \min\{x,y\} 个点它们的位权先贡献掉 2^i,剩下的部分可以迭代解决。这个东西的复杂度一言以蔽之就是 T(n)=2T(\min\{x,y\})+n,是 O(n\log n) 的。

于是我们就做完了。这个东西看着就一脸好写,实际上也就是非常好写以至于 LCA 可以占到算法其他核心部分的 50\% 长度左右。

挂个P8820 [CSP-S 2022] 数据传输代码吧:

#include <bits/stdc++.h>
using namespace std;
constexpr int N = 2e5 + 9, H = 18;
typedef long long ll;
int n, m, k;
inline void cmin(auto& x, auto y) { x > y && (x = y); }
struct mat {
  ll a[3][3];
  ll* operator[](int x) { return a[x]; }
  mat() { memset(a, 0x3f, sizeof a); }
  explicit mat(ll x, ll y) : mat() {
    for (int i = 0; i < k; ++i) a[i][0] = x;
    for (int i = 1; i < k; ++i) a[i - 1][i] = 0;
    if (k == 3) a[1][1] = y;
  }
};
mat op(mat a, mat b) {
  mat c;
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < k; ++j)
      for (int k = 0; k < ::k; ++k) cmin(c[i][j], a[i][k] + b[k][j]);
  return c;
}
mat e() {
  mat c;
  for (int i = 0; i < k; ++i) c[i][i] = 0;
  return c;
}
typedef pair<mat, mat> pmt;
pmt op_p(pmt a, pmt b) {
  return {op(a.first, b.first), op(b.second, a.second)};
}
pmt e_p() { return {e(), e()}; }
vector<int> es[N], ft[N];
int v[N], mn[N], fa[N], d[N], c[N], dfn[N], t, st[H][N];
mat a[N];
vector<pmt> sum[N], sm[N];
void dfs1(int x) {
  ++c[d[x] = d[fa[x]] + 1], dfn[x] = ++t;
  for (int y : es[x]) erase(es[y], fa[y] = x), dfs1(y);
}
void dfs2(int x) {
  int y = x;
  pmt s = e_p();
  for (int i = 0, w = d[x] & -d[x]; i < w && y; ++i)
    sum[x].push_back(s = op_p(s, pair{a[y], a[y]})), y = fa[y];
  if (sum[x].shrink_to_fit(), y) {
    sm[x].push_back(s), ft[x].push_back(y);
    while (!ft[y].empty()) {
      sm[x].push_back(s = op_p(s, sum[y].back()));
      ft[x].push_back(y = ft[y][0]);
    }
    sm[x].shrink_to_fit(), ft[x].shrink_to_fit();
  }
  for (int y : es[x]) dfs2(y);
}
void build() {
  dfs1(1);
  for (int i = 1; i <= n; i <<= 1) {
    int s = 0;
    for (int j = (i << 1) - *d; j <= n; j += i << 1) s += c[j];
    for (int j = i - *d; j <= n; j += i << 1) s -= c[j];
    if (s > 0) *d |= i;
  }
  for (int i = 1; i <= n; ++i) d[i] += *d;
  dfs2(1);
  for (int i = 1; i <= n; ++i) st[0][dfn[i]] = d[fa[i]];
  for (int i = 0; i + 1 < H; ++i)
    for (int j = 1; j + (1 << (i + 1)) <= n + 1; ++j)
      st[i + 1][j] = min(st[i][j], st[i][j + (1 << i)]);
}
int lca(int x, int y) {
  if (x == y) return d[x];
  if ((x = dfn[x]) > (y = dfn[y])) swap(x, y);
  int h = __builtin_ia32_bsrsi(++y - ++x);
  return min(st[h][x], st[h][y - (1 << h)]);
}
pmt qlink(int x, int l) {
  int r = d[x], h = 1 << __builtin_ia32_bsrsi(l ^ r);
  pmt res = e_p();
  if (int k = __builtin_popcount(r & ~-h))
    res = sm[x][--k], x = ft[x][k], r &= -h;
  if (l != r) res = op_p(res, sum[x][r - l - 1]);
  return res;
}
mat query(int u, int v) {
  int z = lca(u, v);
  mat res = qlink(u, z - 1).first;
  if (d[v] != z) res = op(res, qlink(v, z).second);
  return res;
}
signed main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n >> m >> k;
  memset(mn + 1, 0x3f, n * sizeof(int));
  for (int i = 1; i <= n; ++i) cin >> v[i];
  for (int i = 1, x, y; i < n; ++i) {
    cin >> x >> y, cmin(mn[x], v[y]), cmin(mn[y], v[x]);
    es[x].push_back(y), es[y].push_back(x);
  }
  for (int i = 1; i <= n; ++i) a[i] = mat(v[i], mn[i]);
  build();
  for (int x, y; m; --m) cin >> x >> y, cout << query(x, y)[k - 1][0] << '\n';
  return cout << flush, 0;
}

如果有 WRM 洁癖的话可以把那个 popcount 换成预处理或者干脆用最低位而不是 1 的数量来指示,都一样。

bsrsi 那个东西等价于 __lg。个人癖好。

以上。