两种压缩 01-Trie 树:压位和压链

· · 算法·理论

虽然“压位 01-Trie”和“压链 01-Trie”的编辑距离仅为 1,但是它们的区别很大。

压位 01-Trie,动态前驱后继

压位 01-Trie 维护一个不可重集合,解决插入一个数,删除一个数,查询一个数的前驱或后继的问题。设值域为 V,则时间复杂度为单次操作 O(\log_w V),代价是空间复杂度 O(V/w),其中 w=64

压位 01-Trie 是 w 叉的 Trie 树,但和常规认识不同的是,它不动态开点,而是直接把整个值域范围开满,然后删掉最后一层的儿子信息。不过好在它的常数足够小,V=2^{30} 也只是需要一百多 MB 的空间。

实现细节:这个 Trie 树一共有五层(\log_w V),层数越低在树上越深。每个结点有一个颜色,如果结点的子树里面有至少一个数,则它是黑的,否则是白的。每一层的所有结点都记录一个 uint64_t,其中的第 i 位表示它的第 i 个儿子的颜色。

技术细节:


constexpr int bitctz(uint64_t x) { return __builtin_ctzll(x); }
constexpr int bitclz(uint64_t x) { return __builtin_clzll(x); }
struct cptrie {
  static constexpr int m = 63, g = 6;
  int L;
  vector<uint64_t> a[5];
  cptrie(int sz) : L(0) {
    while (true) {
      a[L].resize(((sz - 1) >> (L + 1) * g) + 1);
      if (a[L++].size() <= 1) break;
    }
  }
  void insert(uint32_t x) {
    for (int i = 0; i < L; i++) {
      auto mask = 1ull << (x >> i * g & m);
      auto &v = a[i][x >> (i + 1) * g];
      if (v & mask) break; 
      v |= mask;
    }
  }
  void remove(uint32_t x) {
    for (int i = 0; i < L; i++) {
      auto mask = 1ull << (x >> i * g & m);
      auto &v = a[i][x >> (i + 1) * g];
      if (v &= ~mask) break; 
    }
  }
  uint32_t succ(uint32_t x) const {
    for (int i = 0; i < L; i++) {
      int cur = x >> i * g & m;
      auto v = a[i][x >> (i + 1) * g];
      if (v >> cur > 1) { // >> 64 行为未定义
        uint32_t res = x & (-1u << (i + 1) * g);
        res += (bitctz(v >> (cur + 1)) + cur + 1) << i * g;
        for (int j = i; j--; ) res += bitctz(a[j][res >> (j + 1) * g]) << j * g;
        return res;
      }
    }
    return 0;
  }
  uint32_t prev(uint32_t x) const {
    for (int i = 0; i < L; i++) {
      int cur = x >> i * g & m;
      auto v = a[i][x >> (i + 1) * g];
      if (v & ~(-1ull << cur)) {
        uint32_t res = x & (-1u << (i + 1) * g);
        res += (m - bitclz(v & ~(-1ull << cur))) << i * g;
        for (int j = i; j--; ) res += (m - bitclz(a[j][res >> (j + 1) * g])) << j * g;
        return res;
      }
    }
    return 0;
  }
};

压链 01-Trie,线性空间平衡树

压链 01-Trie 就是平衡树,拥有平衡树的几乎一切特征,解决平衡树的问题。设值域为 V,则时间复杂度 O(\log_2 V),空间复杂度 O(n)。劣势是时间复杂度是可以卡满的,有时显得比较危险。优势是继承了 01-Trie 解决异或问题的能力。

压链 01-Trie 的核心出装和它的名字一样,将所有的链压缩成一条边,等价的表述是缩二度点(广义串并联图的 Compress)或者建立所有叶子的虚树。

实现细节:压链 01-Trie 是 Leafy 的平衡树,每个结点要么是叶子,要么有两个儿子(根是一个特例,但是我们可以不用写出根)。每个结点有一个 dep 深度,表示子树内所有树的高 dep 位都是相同的,叶子的 dep=64,结点的父亲的 dep 一定比该节点小。之所以叫深度是因为这是在未压缩的 Trie 树上的深度。另外每个结点记录 val 表示子树内最大值(很类似 WBLT 的实现?)。

我们实现 merge 操作。merge 操作传入两个结点 p,q 要满足它们的父亲的深度是相同的(如果没有父亲,则父亲的深度为零)。分类讨论:

其余操作和正常的平衡树差不多。向下找特定值的时候,根据记录下来的 depval 判断能否沿着链向下(有可能走不下去),具体看代码中的 rank

可持久化合并的版本:(普通平衡树)

constexpr int N = (1.1e6 + 10)*16;
namespace compressed_trie {
  int ch[N << 1][2], dep[N << 1], tot;
  uint64_t val[N << 1];
  int infos[N << 1];
  int newnode(uint64_t v, int info) {
    int p = ++tot;
    val[p] = v, dep[p] = 64;
    infos[p] = info; // 子树大小的意思
    return p;
  }
  template <class Func>
  int merge(int p, int q, Func&& func) {
    if (!p || !q) return p + q;
    if (dep[p] < dep[q]) swap(p, q);
    int z = ++tot;
    if (val[p] == val[q] && dep[q] == 64) {
      val[z] = val[p], dep[z] = dep[p];
      infos[z] = func(infos[p], infos[q]);
      return z;
    }
    int sd = val[p] == val[q] ? 64 : __builtin_clzll(val[p] ^ val[q]);
    if (dep[p] == dep[q] && sd >= dep[q]) {
      dep[z] = dep[p];
      ch[z][0] = merge(ch[p][0], ch[q][0], forward<Func>(func));
      ch[z][1] = merge(ch[p][1], ch[q][1], forward<Func>(func));
    } else if (sd >= dep[q]) {
      memcpy(ch[z], ch[q], sizeof ch[0]), dep[z] = dep[q];
      int r = val[p] >> (63 - dep[z]) & 1;
      ch[z][r] = merge(ch[z][r], p, forward<Func>(func));
    } else {
      dep[z] = sd;
      ch[z][val[p] >> (63 - dep[z]) & 1] = p;
      ch[z][val[q] >> (63 - dep[z]) & 1] = q;
    }
    val[z] = val[ch[z][1]];
    if (dep[z] < 64) assert(ch[z][0] && ch[z][1]), infos[z] = func(infos[ch[z][0]], infos[ch[z][1]]);
    return z;
  }
  int rank(uint64_t v, int p) {
    if (v > val[p]) return infos[p]; // 根的特判
    int res = 0;
    while (dep[p] < 64) {
      if (v <= val[ch[p][0]]) p = ch[p][0];
      else {
        res += infos[ch[p][0]];
        if (val[ch[p][1]] >> (64 - dep[ch[p][1]]) == v >> (64 - dep[ch[p][1]])) p = ch[p][1]; else break;
      }
    }
    return res;
  }
  int kth(int r, int p) {
    while (dep[p] < 64) {
      if (infos[ch[p][0]] < r) r -= infos[ch[p][0]], p = ch[p][1];
      else p = ch[p][0];
    }
    return val[p];
  }
}

没有可持久化合并的版本:(可并堆 2)

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, __VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
constexpr int N = 4e6 + 10;
namespace compressed_trie {
  int ch[N << 1][2], dep[N << 1], tot;
  uint64_t val[N << 1];
  int siz[N << 1];
  int newnode(uint64_t v, int c) {
    v ^= 1ull << 63; // 负数的特判,异或最高位从而使无符号数能正常比较
    int p = ++tot;
    val[p] = v, dep[p] = 64;
    siz[p] = c;
    return p;
  }
  int merge(int p, int q) {
    if (!p || !q) return p + q;
    if (dep[p] < dep[q]) swap(p, q);
//  int z = ++tot;
    int z = q;
    if (val[p] == val[q] && dep[q] == 64) {
      val[z] = val[p], dep[z] = dep[p];
      siz[z] = siz[p] + siz[q];
      return z;
    }
    int sd = val[p] == val[q] ? 64 : __builtin_clzll(val[p] ^ val[q]);
    if (dep[p] == dep[q] && sd >= dep[q]) {
      dep[z] = dep[p];
      ch[z][0] = merge(ch[p][0], ch[q][0]);
      ch[z][1] = merge(ch[p][1], ch[q][1]);
    } else if (sd >= dep[q]) {
      memcpy(ch[z], ch[q], sizeof ch[0]), dep[z] = dep[q];
      int r = val[p] >> (63 - dep[z]) & 1;
      ch[z][r] = merge(ch[z][r], p);
    } else {
      z = ++tot;
      dep[z] = sd;
      ch[z][val[p] >> (63 - dep[z]) & 1] = p;
      ch[z][val[q] >> (63 - dep[z]) & 1] = q;
    }
    val[z] = val[ch[z][1]];
    siz[z] = siz[ch[z][0]] + siz[ch[z][1]];
    return z;
  }
  uint64_t getmin(int p) {
    while (dep[p] < 64) {
      if (siz[ch[p][0]]) p = ch[p][0];
      else p = ch[p][1];
    }
    return val[p] ^ (1ull << 63);
  }
}
namespace td = compressed_trie;
int n, m, rt[N];
LL a[N];
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);  
#endif
  cin >> n >> m;
  for (int i = 1; i <= n; i++) cin >> a[i], rt[i] = td::newnode(a[i], +1);
  while (m--) {
    int op, x, y;
    cin >> op >> x;
    if (op == 0) {
      cin >> y;
      rt[x] = td::merge(rt[x], td::newnode(a[y], -1));
    } else if (op == 1) {
      cout << (LL)td::getmin(rt[x]) << endl;
    } else if (op == 2) {
      cin >> y;
      rt[x] = td::merge(rt[x], rt[y]);
    } else {
      cin >> y;
      rt[x] = td::merge(rt[x], td::newnode(a[y], -1));
      cin >> a[y];
      rt[x] = td::merge(rt[x], td::newnode(a[y], +1));
    }
  }
  return 0;
}