Wavelet Matrix / Tree 更快更好地解决区间第k小问题

· · 个人记录

Wavelet Matrix / Tee

小波(波纹疾走?)树是一个非常优雅的数据结构。尽管已经被提出快20年,但是国内还是很少有人用他去解决算法竞赛中的问题,多用于学术界。常用于解决区间查询,以及求区间分割点(Quantile,分位点)的问题。属于"简洁数据结构"(Succinct Data Structure) 的一种。Wavelet是一种信号的形状,很像波浪。很遗憾Wiki也没有关于Wavelet Tree的中文界面,所以我习惯性叫它波浪形态树。

Wavelet Tree的复杂度是基于字母表σ的,因此如果我们要考虑的不只是字符串,而是数列,这个时间和空间消耗会非常恐怖。因此我们引入了matrix的方法来优化,时间和空间都更加优秀。

Wavelet Matrix在算法竞赛中是朴素Wavelet Tree的上位替代,因此我们着重介绍前者。

前置知识:基数排序,稳定排序,位运算基础。

太长不看版

如果你对原理没有什么兴趣,可以直接抄走我抄来的模板,用标注的API解决问题即可。

本文的例子来自 https://miti-7.hatenablog.com/entry/2018/04/28/152259 这是最详细的Wavelet Matrix的讲解文章,没有之一。

本文大量引用了该博客的图片,基于MIT License,我可以对进行再创作。

本文不会对实现做特别多的解释,仅抛砖引玉。

能解决的问题

区间第k小

区间某个数出现的频率。

区间小于等于某个数的个数

...

以上都是 log(σ)的复杂度

初始化

坐标是数值的最高位bit进行稳定排序生成的。

给定数列,T = [5, 4, 5, 5, 2, 1, 5, 6, 1, 3, 5, 0]

我们只考虑前 3 位作为简单情况,因为最高值是 5(101)_23 位足够了。

  1. 我们取所有数字的最高位bit,可以做成下图的形式。

  1. T1 进行稳定排序(stable_partition,C++自带的API),将 T1,也就是最高位bit是0的放到左边,T1里有 [2, 1, 1, 3] 四个数字的最高位是0,因此我们可以生成 T2

    蓝色的线就是我们的分割点,此时我们取所有数字的第二位,2 的二进制表示是 010,因此是 01 的二进制表示是 001,所以放 0,以此类推,生成了 B2

  2. 将T2基于B2进行稳定排序,也就是把 [1, 1, 0, 5, 4, 5, 5, 5, 5] 放到了左侧,其余第二高位Bit位放到了右侧,生成了T3

    我们取所有数字的第三高位,其实也就是最后一位Bit位。制成了 B3

  3. 同理,我们利用 B3,将 T3分为 [0, 4, 2, 6][1, 1, 5, 5, 5, 5, 3] 两部分。

    我们记录每一个数值开始的坐标,因为所有字符已经连续。

    最终我们获得了以下的坐标。 ![f:id:MitI_7:20180426193432p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193432.png)

具体操作

access

我们从例子入手: $T = [5, 4, 5, 5, 2, 1, 5, 6, 1, 3, 5, 0]$ 的时候,我们求 $access(7) ![f:id:MitI_7:20180426193643p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193643.png) 1. 我们先来看看 $B1[7]$,$B1$ 本身就是从最原始的数列 $T1$ 转换来的,所以我们很明白这个 $B1[7]$ 就是储存的5的第一位。 2. 如何通过 $B1[7] == 1$ 这个信息,找到对应在 $B2$ 里的信息呢? 考虑下我们是如何从 $B1$ 转换到 $B2$ 的。首先 $B1$ 里值为 $0$ 的话会归类到 $B2$ 的左侧,值为 $1$ 的话是 $B2$右侧。 之前我们得到 $B1[7]$ 是1,因此可以推断: $B1$里 $0$ 的个数 + $B1$ 到 $7$ 这个位置1的个数,其实就是 $B1[7]$ 移动到的位置。 我们实际来看,$B1[7]$ 包括自己和前面一共有 $5$ 个 $1$,记为 $rank_1(B1, 7)$,整个 $B1$ 有 $5$个 $0$,就是第 $10$ 位。$B2[10]$ 就是我们要找的信息,等于 $0$。因此我门可以反推得到 $T[7]$ 的第二高位的Bit是 $0$。记为 ![f:id:MitI_7:20180426193712p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193712.png) 3. 同样地,基于 $B2[10] = 0$, 我们找到 $B3$ 里他的位置,$0$ 会被分到 $B3$ 的左侧,因此我们只需要得到 $B2$ 中,$B2[10]$ 及其之前有多少个0, 记为 $rank_0(B2, 10) = 8$,就可以推出在 $B3$ 中的位置了,也就是 $8$ 个。 ![f:id:MitI_7:20180426193753p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193753.png) 4. 我们得到 $B3[8] = 1$ 。因此三位Bit全部得到,所以 $T[7] = 101$。 总结一下我们得到的三步 ![f:id:MitI_7:20180426193815p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193815.png) #### rank $rank$ 是求 数列 $T$ 的 $i$ 位置前,数值 $c$ 出现的次数的函数。 比如说 $T=[5, 4, 5, 5, 2, 1, 5, 6, 1, 3, 5, 0] $,$rank_5(T, 9)$ ,就是求前 $9$位( $1$ - based)里 $5$的个数,也就是 $4$个。 1. 红线前面就是我们区间的范围。目标是求左侧 $5$ 出现的次数。 还记得我们的 $T4$ 长啥样吗?所有一样的数字会被归类,而且根据stable sort的特性,所有元素的相对位置还会被保留。 因此其实我们只要 $9$ 之前**最后一个 $5$**( $idx = 7$ )在 $T4$ 的位置,就能找到前 $9$ 所有 $5$ 的个数了。 ![f:id:MitI_7:20180426193925p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193925.png) 2. $5$ 的最高位bit是 $1$,因此我们要考虑 $B1$ 中最高位也是 $1$ 的部分会如何转移到 $B2$ 里去。考虑到 $B1$ 里 $0$的个数是 $5$,$rank_1(B1, 9) = 6$,在 $B2$ 里,所有最高位Bit为 $1$ 的数字,都会在区间 $[6, 11]$ 中,我们找最后一个最高Bit位是 $1$ 的,也就是 $11$的位置。如果你忘了为什么要找最后一个最高Bit位是 $1$,可以看看上一条。 ![f:id:MitI_7:20180426193954p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426193954.png) 3. $5$ 的第二最高Bit位是 $0$,考虑 $B2$ 红线内里是 $0$ 的数字,因为 $B2$ 前 $11$ 个的数字里有 $8$ 个 $0$,我们找最后一个,所以我们就该找 $B3$ 的 $8$了。 ![f:id:MitI_7:20180426194037p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194037.png) 4. $5$ 的最后一位Bit位是 $1$,因此我们首先考虑 $B3$ 里有 $4$ 个 $0$,肯定是在 $1$ 的前面,再考虑 $B3$ 的 $8$ 及其前部有多少个 $1$,即 $rank_1(B3, 8) = 6$, 加起来得 $10$。就是我们 $T4$里的坐标。 ![f:id:MitI_7:20180426194111p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194111.png) 我们在构筑这个表的时候,记录了每个数字最早出现的坐标,因此我们可以得到 $10 - 7 + 1= 4$。$rank_5(T, 9) = 4$ 得出结论。 ​ ![f:id:MitI_7:20180426194126p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194126.png) #### select $select$ 是求数列 $T$ 中第 $i$ 个数值 $c$ 的位置的函数。 比如说 $T = [5, 4, 5, 5, 2, 1, 5, 6, 1, 3, 5, 0]$ , 求 $select_5(T, 4)$。很明显第4个5是在位置7的地方。$select$本质上其实是 $rank$ 的逆运算。 1. 我们从 $T4$ 开始看,因为我们已经预处理出 $7$ 是 $5$开始的位置,那么很容易求出 $4$开始的位置是在 $7 + 4 - 1 = 10$。我们的目标是从$T4$反推出 $T4[10]$ 在 $B1$ 里的位置。 ![f:id:MitI_7:20180426194554p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194554.png) 2. $T4[10]$ 是从 $B3$ 的哪里转移过来的呢? $5$ 的最低位Bit是 $1$,考虑 $B3$ 中有 $4$ 个 $0$,则有 $4 + rank_1(B3, x) = 10$, 问题转换成求$select_1(B3, 10 - 4)$,答案是 $8$。 ![f:id:MitI_7:20180426194517p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194517.png) 3. $B3[8]$ 是从 $B2$ 的哪个位置转移过来的呢?考虑到 $5$ 的第二高位Bit是 $0$ ,我们要找到$rank_0(b2, x) = 8$ 即 $select(B2,8)$,得到 $10$. ![f:id:MitI_7:20180426194434p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194434.png) 4. 最后一步考虑 $B2[10]$ 是从 $B1$ 哪个位置转移过来的,因为 $5$ 最高位是 $1$ ,问题转化成 $5 + rank(b1 ,x) = 10$,即 $select(B1, 10 - 5)$,得到答案 $7$。也就是我们原问题的答案。 ![f:id:MitI_7:20180426194345p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194345.png) 回顾一下整个过程,可以倒叙看。 ![f:id:MitI_7:20180426194607p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180426/20180426194607.png) #### quantile $quantile$ 是求数列 $T$ 的区间 $[l, r]$ 中第 $k$ 小的值,即区间第 $k$ 小问题。 举例说明 $T=[5, 4, 5, 5, 2, 1, 5, 6, 1, 3, 5, 0]$, ​求$quantile(T,2,11,8)

将区间 [2, 11] 进行排序,得到 [1, 1, 2, 3, 4, 5, 5, 5, 5, 6],取第 8 个,就是 5

quantile(T,2,11,8)=5

  1. 从原数列T1出发,要求其中的第8小的数字。

  2. 被红线圈住的 B1 里有 0,有 1,分别代表最高Bit位置都是什么。很明显 rank_0(B1, 11) = 4, rank_0(B1,1) = 0,那么我们有 4061,其中第 8 小的最高位肯定还是 1。我们所求的区间在 211 之间,我们思考一下会在 B2 里如何转移。红线内侧的第一个 1B2 里的位置如何得到呢?

    首先可以发现 rank_1(b1, 1) = 1,即左端红线外侧有一个 1,因此红线内部第一个 1,是 B1 整体的第二个 1

    因为 B1 中有 50,自然可以发现 B2 中的位置是 5 + 2 = 7.

    红线内侧最后一个 1,现在坐标是 11,在 B2 的位置如何求呢?

    只要计算区间 [1, 11] 内有多少1即可。rank_1(B1, 11) = 7,在 B2 中的位置就是 12

    注意,这里我们直接踢掉了前 4个数字,当前bit位是 0,所以踢掉的这几个数字已经不可能是我们需要的 k小数字了。

    新的红线区间就是 [5, 12]

  3. 意味着他们转移到 $B3$ 的时候会是 $[0, 0, 0, 0, 0, 1]$ 的顺序。然后我们考虑一下我们在 $B2$里要找哪个数字。 其实我们要求的东西已经变了,我们在已经在 $B1$ 里排除了 $4$个$0$,因此我们现在找的其实是第 $4$ 个数,也是 $0$ 。同时 $B2$ 的红线内的 $0$ 在 $B3$ 的 $ [5, 9]$ 区间。 ![f:id:MitI_7:20180428100552p:plain](https://cdn-ak.f.st-hatena.com/images/fotolife/M/MitI_7/20180428/20180428100552.png)
  4. 最后一步,我们发现红线里已经变成了 [0, 1, 1, 1, 1],上一步因为我们没法排除任何一个数,所以这里还是要取第4小的Bit。可以看出来是1,因此 [6, 9] 都可能是我们的答案。

    我们考虑映射到T4中,因为右侧有 20,会在下一轮排序跑到 [6, 9] 前面,因此我们新区间就是 [8, 11],我们要的数就是 5

    复习下流程:

本文没有提到的一些东西

很关键的就是,如何方便地管理上面所说的 B1, B2, B3

可以用另外一个简洁数据结构,他的实现比较简单,或许你一看下面的代码就懂,但是可以讲讲。

区间和可以求吗?可以。

可以套树状数组和线段树吗?也可以 可以修改吗?可以。 可以持久化吗?好像也可以 我们上面说的几个操作已经可以解决掉本题 区间静态第k小查询了。 ```cpp void solve() { int n, q; std::cin >> n >> q; std::vector<int> a(n); for (int & x : a) { std::cin >> x; } // 离散化 auto input = a; auto backup = a; std::sort(a.begin(), a.end()); a.erase(std::unique(a.begin(), a.end()), a.end()); auto find = [&](int x) { return std::lower_bound(a.begin(), a.end(), x) - a.begin(); }; for (int i = 0; i < n; i++) { input[i] = find(backup[i]); } // 建立WM和查询 WaveletMatrix wm(input); for (int i = 0; i < q; i++) { int l, r, k; std::cin >> l >> r >> k; l--; std::cout << a[wm.quantile(k - 1, l, r)] << '\n'; } } ``` ## 模板 抄模板没什么丢人的。 抄了就是我的(x),详细可以看看注释。 ```c++ // https://kopricky.github.io/code/DataStructure_Advanced/wavelet_matrix.html struct BitRank { // block 管理一行一行的bit std::vector<unsigned long long> block; std::vector<unsigned int> count; BitRank() {} // 位向量长度 void resize(const unsigned int num) { block.resize(((num + 1) >> 6) + 1, 0); count.resize(block.size(), 0); } // 设置i位bit void set(const unsigned int i, const unsigned long long val) { block[i >> 6] |= (val << (i & 63)); } void build() { for (unsigned int i = 1; i < block.size(); i++) { count[i] = count[i - 1] + __builtin_popcountll(block[i - 1]); } } // [0, i) 1的个数 unsigned int rank1(const unsigned int i) const { return count[i >> 6] + __builtin_popcountll(block[i >> 6] & ((1ULL << (i & 63)) - 1ULL)); } // [i, j) 1的个数 unsigned int rank1(const unsigned int i, const unsigned int j) const { return rank1(j) - rank1(i); } // [0, i) 0的个数 unsigned int rank0(const unsigned int i) const { return i - rank1(i); } // [i, j) 0的个数 unsigned int rank0(const unsigned int i, const unsigned int j) const { return rank0(j) - rank0(i); } }; class WaveletMatrix { private: unsigned int height; std::vector<BitRank> B; std::vector<int> pos; public: WaveletMatrix() {} WaveletMatrix(std::vector<int> vec) : WaveletMatrix(vec, *std::max_element(vec.begin(), vec.end()) + 1) {} // sigma: 字母表大小(字符串的话),数字序列的话是数的种类 WaveletMatrix(std::vector<int> vec, const unsigned int sigma) { init(vec, sigma); } void init(std::vector<int>& vec, const unsigned int sigma) { height = (sigma == 1) ? 1 : (64 - __builtin_clzll(sigma - 1)); B.resize(height), pos.resize(height); for (unsigned int i = 0; i < height; ++i) { B[i].resize(vec.size()); for (unsigned int j = 0; j < vec.size(); ++j) { B[i].set(j, get(vec[j], height - i - 1)); } B[i].build(); auto it = stable_partition(vec.begin(), vec.end(), [&](int c) { return !get(c, height - i - 1); }); pos[i] = it - vec.begin(); } } int get(const int val, const int i) { return val >> i & 1; } // [l, r) 中val出现的频率 int rank(const int val, const int l, const int r) { return rank(val, r) - rank(val, l); } // [0, i) 中val出现的频率 int rank(int val, int i) { int p = 0; for (unsigned int j = 0; j < height; ++j) { if (get(val, height - j - 1)) { p = pos[j] + B[j].rank1(p); i = pos[j] + B[j].rank1(i); } else { p = B[j].rank0(p); i = B[j].rank0(i); } } return i - p; } // [l, r) 中k小 int quantile(int k, int l, int r) { int res = 0; for (unsigned int i = 0; i < height; ++i) { const int j = B[i].rank0(l, r); if (j > k) { l = B[i].rank0(l); r = B[i].rank0(r); } else { l = pos[i] + B[i].rank1(l); r = pos[i] + B[i].rank1(r); k -= j; res |= (1 << (height - i - 1)); } } return res; } int rangefreq(const int i, const int j, const int a, const int b, const int l, const int r, const int x) { if (i == j || r <= a || b <= l) return 0; const int mid = (l + r) >> 1; if (a <= l && r <= b) { return j - i; } else { const int left = rangefreq(B[x].rank0(i), B[x].rank0(j), a, b, l, mid, x + 1); const int right = rangefreq(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j), a, b, mid, r, x + 1); return left + right; } } // [l,r) 在[a, b) 值域的数字个数 int rangefreq(const int l, const int r, const int a, const int b) { return rangefreq(l, r, a, b, 0, 1 << height, 0); } int rangemin(const int i, const int j, const int a, const int b, const int l, const int r, const int x, const int val) { if (i == j || r <= a || b <= l) return -1; if (r - l == 1) return val; const int mid = (l + r) >> 1; const int res = rangemin(B[x].rank0(i), B[x].rank0(j), a, b, l, mid, x + 1, val); if (res < 0) return rangemin(pos[x] + B[x].rank1(i), pos[x] + B[x].rank1(j), a, b, mid, r, x + 1, val + (1 << (height - x - 1))); else return res; } // [l,r) 在[a,b) 值域内存在的最小值是什么,不存在返回-1 int rangemin(int l, int r, int a, int b) { return rangemin(l, r, a, b, 0, 1 << height, 0, 0); } }; ``` ## Reference 1. Wavelet Trees for Competitive Programming https://ioinformatics.org/journal/v10_2016_19_37.pdf 2. https://users.dcc.uchile.cl/~gnavarro/ps/spire12.4.pdf 3. https://rsk0315.hatenablog.com/entry/2022/01/09/152028