猫树

· · 算法·理论

猫树是一种很巧妙的数据结构,可以高效支持静态区间的询问。

在学习本文之前,请先掌握线段树。

限于个人水平,文中可能存在不足之处,望不吝批评指正。

猫树简介

猫树本质是一棵线段树。但相比于线段树,猫树不支持修改,且空间复杂度为更劣的 \mathcal{O}(n \log n)

不过猫树也有其优点。相比于 ST 表,猫树可以处理所有满足结合律的运算(包括不允许重复贡献的运算);相比于线段树,猫树的单次查询复杂度更优。

例如,需要处理区间矩阵乘法或线性基的时候,由于其单次合并代价高,使用猫树可以将询问时的合并次数减少到 \mathcal{O}(1),优于线段树的 \mathcal{O}(\log n) 次。

建树(Build)

先来讲猫树如何建树。与普通线段树类似,也是不断把大区间分成两个子区间,递归建树。

不过猫树需要额外维护每个区间到其 mid 的信息。

具体地,假设建树时我们所在节点对应区间 [l, r],中点为 mid。我们需要维护 [l, mid] 中每个点到 mid 的一段后缀信息,以及 (mid, r] 中每个点到 mid + 1 的一段前缀信息

例如对于区间 [1, 5],其 mid = 3,维护的运算为加法。此时对于这个节点我们需要维护 1 \sim 32 \sim 33 \sim 33 个区间的和,以及 4 \sim 44 \sim 52 个区间的和。

由于每一层都需要维护 \mathcal{O}(n) 的信息,因此总空间复杂度为 \mathcal{O}(n \log n),劣于线段树。

不过直接这样维护遇到某些题是过不去的,因此还有一个技巧:一个点不可能同时需要维护前缀或后缀信息(因为一个点与 mid 的关系是唯一确定的)。具体地,我们可以直接把 presuf 压成同一个数组,因为同一个位置在两个数组中只有一个有值。压完后空间可以减少一半,这也是通过P11265的关键之一。

具体地,可以令 t_{d, i} 表示在猫树的第 d 层中,位置 i(下标 i)到其 mid 的区间信息(方向不确定),这样就可以省一半空间了,具体实现后文会提到。

特别的,建树时可以把序列长度 n 扩展到 2 的整数次幂(多余的位置用单位元填充),这样就不需要存节点编号了(后文也会提到)。

查询(Query)

下面讲如何查询。假设我们想要查询的区间为 [l, r],且树上对应 [l, l] 的节点和对应 [r, r] 的节点的 LCA 为 pp 对应的区间为 [L, R],区间 [L, R] 的中点为 mid。则有以下两条性质:

根据上述性质,我们一定可以将区间 [l, r] 拆成 [l, mid](mid, r] 两个区间,而这两个区间的信息在建树时维护过!因此,我们做到了 \mathcal{O}(1) 次合并。

不过还有一个细节问题,就是如何确定 l, r 的 LCA 呢?这里我们假设已经把长度 n 变为 2 的整数次幂,且下标范围 [0, n)。此时根据我们对 t 数组的定义,只需要求出其 LCA 的层数 d,就可以合并 t_{d, l}t_{d, r} 回答询问了。

那么如何求 d 呢?先列表格找规律:

下标 二进制 d = 0 d = 1 d = 2
0 000
1 001
2 010
3 011
4 100
5 101
6 110
7 111

不难发现,在第 d 层,两个下标是否在同一个半块,完全由它们二进制的第 d 位是否相同决定。

对于 l < r,其二进制最高不同位一定是 l0r1。而 lr 的异或结果中,这一位恰好是最高位的 1

因此在 > d 的层(更高位)上,lr 的那些位完全相同,所以它们一直处在同一个半块中,从未被切分开;在第 d 层,它们的第 d 位不同,一个在左半块,一个在右半块,这是它们首次被分割到不同半块,即 LCA 的位置。

因此,使用 31 - __builtin_clz(l ^ r) 的值,可以快速计算 d

参考实现

这里给出P11265 【模板】静态区间半群查询的参考实现。

代码中,buildquery 函数是重点,注意矩阵乘法不满足交换律

注意本题中如果 presuf 数组不合并,则空间会爆炸。

但是,我们知道 pre_{d, i}suf_{d, i} 中只有一个值有意义,因此直接令 t_{d, i} 表示有意义的那个值即可。

namespace Solution{
    int n, m, b; ull sd;
    inline ull splitmix64(ull x){
        x += 0x9e3779b97f4a7c15;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }
    inline ull rnd(){
        sd ^= sd << 13, sd ^= sd >> 7;
        return sd ^= sd << 17;
    }
    const int inf = 1e9 + 114514;
    struct Matrix{
        int a[2][2];
        Matrix() : a({{0, inf}, {inf, 0}}) {}
        Matrix(int x, int y, int z, int w) : a({{z, y}, {x, w}}) {}
        Matrix& operator = (Matrix o){
            fo(i, 0, 1) fo(j, 0, 1) a[i][j] = o.a[i][j];
            return *this;
        }
        Matrix operator * (Matrix o) const{
            Matrix res;
            fo(i, 0, 1){
                fo(j, 0, 1){
                    int Min = inf;
                    fo(k, 0, 1){
                        Min = min(Min, a[i][k] + o.a[k][j]);
                    }
                    res.a[i][j] = Min;
                }
            }
            return res;
        }
        Matrix& operator *= (Matrix o){
            *this = *this * o;
            return *this;
        }
    }a[1 << 20], t[20][1 << 20]; // 对于每个位置,pre 和 suf 仅有一个生效,故合并为 t
    inline void genmat(Matrix& mat, ull x){
        fo(i, 0, 1) fo(j, 0, 1) mat.a[i][j] = x >> ((i << 1 | j) << 4) & 255;
        return;
    }
    inline void genqry(int& l, int& r, int n){
        if((rnd() & 1) && b){
            int c = rnd() % (n - b);
            l = rnd() % (n - c) + 1, r = l + c;
        }else{
            l = rnd() % n + 1, r = rnd() % n + 1;
            if (l > r) swap(l, r);
        }
        return;
    }
    inline int trans(Matrix x, Matrix y){
        int res = 0;
        fo(i, 0, 1) fo(j, 0, 1) res += x.a[i][j] ^ y.a[i][j];
        return res;
    }
    inline void build(int l, int r, int d){
        if(l == r) return;
        int mid = (l + r) >> 1;
        t[d][mid] = a[mid]; fd(i, mid - 1, l) t[d][i] = a[i] * t[d][i + 1];
        t[d][mid + 1] = a[mid + 1]; fo(i, mid + 2, r) t[d][i] = t[d][i - 1] * a[i];
        build(l, mid, d - 1), build(mid + 1, r, d - 1);
        return;
    }
    inline Matrix query(int l, int r){
        if(l == r) return a[l - 1];
        l--, r--;
        int d = 31 - __builtin_clz(l ^ r);
        return t[d][l] * t[d][r];
    }
    inline void Solve(){
        rd(n, m, sd, b), sd = splitmix64(sd);
        int x, y, z, w; rd(z, y, x, w); Matrix kv(x, y, z, w);
        fo(i, 1, n) genmat(a[i], rnd());
        // fo(_, 1, n){
        //     fo(i, 0, 1){
        //         fo(j, 0, 1){
        //             cerr << a[_].a[i][j] << ' ';
        //         }
        //         cerr << '\n';
        //     }
        //     cerr << '\n';
        // }
        int N = n, k = 0; n = 1; while(n < N) n <<= 1, k++;
        fo(i, 1, N) a[i - 1] = a[i]; fu(i, N, n) a[i] = Matrix();
        build(0, n - 1, k - 1);
        int ans = 0;
        while(m--){
            int l, r; genqry(l, r, N);
            // cerr << l << ' ' << r << '\n';
            Matrix res = query(l, r);
            ans ^= trans(res, kv);
        }
        wr(ans), pc('\n');
        return;
    }
}

例题

有些例题猫树不是唯一解,但如果想学好猫树,请自觉。

GSS1 - Can you answer these queries I

:::info[题意(加强版)]{open}

维护长为 n 的序列 a,支持 q 次查询,每次询问给出 l,r,求区间 [l, r] 的最大子段和(可以空)。

数据范围:1 \leq n \leq 5 \times 10^41 \leq q \leq \color{red}3 \times 10^7|a_i| \leq 15007

:::

线段树处理这么多询问不太好过,而且不带修改,因此使用猫树。

最大子段和可以通过维护区间和、最大前缀和、最大后缀和以及最大子段和这 4 个信息合并。

于是我们认为需要维护的信息是四元组,并且我们的“运算”相当于普通线段树的 push_up(满足结合律)即可。

:::info[参考代码]

namespace Solution{
    int n, q, x;
    struct Node{ // 最大子段和需要维护的 4 个信息
        int sum, lmx, rmx, ans;
        Node operator + (Node o) const{
            Node res = {0, 0, 0, 0};
            res.sum = sum + o.sum;
            res.lmx = max(lmx, sum + o.lmx);
            res.rmx = max(rmx + o.sum, o.rmx);
            res.ans = max({ans, o.ans, rmx + o.lmx});
            return res;
        }
    }a[1 << 16], t[16][1 << 16];
    inline void build(int l, int r, int d){
        if(l == r) return;
        int mid = (l + r) >> 1;
        t[d][mid] = a[mid]; fd(i, mid - 1, l) t[d][i] = a[i] + t[d][i + 1];
        t[d][mid + 1] = a[mid + 1]; fo(i, mid + 2, r) t[d][i] = t[d][i - 1] + a[i];
        build(l, mid, d - 1), build(mid + 1, r, d - 1);
        return;
    }
    inline Node query(int l, int r){
        if(l == r) return a[l - 1];
        l--, r--;
        int d = 31 - __builtin_clz(l ^ r);
        return t[d][l] + t[d][r];
    }
    inline void Solve(){
        rd(n); fo(i, 1, n) rd(x), a[i] = {x, x, x, x};
        int N = n, k = 0; n = 1; while(n < N) n <<= 1, k++;
        fo(i, 1, N) a[i - 1] = a[i]; fu(i, N, n) a[i] = {0, 0, 0, 0};
        build(0, n - 1, k - 1);
        rd(q); while(q--){
            int l, r; rd(l, r);
            wr(query(l, r).ans), pc('\n');
        }
        return;
    }
}

:::

P11265 【模板】静态区间半群查询

:::info[题意]{open}

给定一个序列 a_1,a_2,\cdots,a_n,其中的每个元素都是一个 2\times2 的矩阵。你需要处理 m 次查询,每次查询给定一个区间 [l,r],你需要求出 \prod_{i=l}^ra_i,其中 \times 符号代表 (\min,+) 矩阵积。

数据范围:1 \leq n,m \leq 10^6,1s 足够,512 MB。

:::

猫树板子题,(\min, +) 矩乘显然满足结合律,于是直接维护即可。

注意本题中如果 presuf 数组不合并,则空间会爆炸。

但是,我们知道 pre_{d, i}suf_{d, i} 中只有一个值有意义,因此直接令 t_{d, i} 表示有意义的那个值即可。

:::info[参考代码]

#include<bits/stdc++.h>
// #define int long long
#define fo(i, l, r) for(decltype((l) + (r)) i = (l); i <= (r); ++i)
#define fd(i, l, r) for(decltype((l) + (r)) i = (l); i >= (r); --i)
#define fu(i, l, r) for(decltype((l) + (r)) i = (l); i <  (r); ++i)
#define y1 zhang_kevin
#define pii pair<int, int>
#define fi first
#define se second
#define vec vector
#define pb push_back
#define eb emplace_back
#define all(v) v.begin(), v.end()
#define ll long long
#define ull unsigned long long
#define flush() (fwrite(obuf, 1, p3 - obuf, stdout), p3 = obuf)
using namespace std;
bool ST;
char ibuf[1 << 20], *p1 = ibuf, *p2 = ibuf, obuf[1 << 20], *p3 = obuf;
inline char gc(){
    if(p1 == p2){
        p1 = ibuf, p2 = ibuf + fread(ibuf, 1, 1 << 20, stdin);
        if(p1 == p2) return EOF;
        return *p1++;
    }
    return *p1++;
}
inline char pc(char ch){
    if(p3 == obuf + (1 << 20)) flush();
    *p3 = ch;
    return *p3++;
}
template<typename type>
inline int rd(type &x){
    x = 0; bool f = 0; char ch = gc();
    while(!isdigit(ch)) f |= ch == '-', ch = gc();
    while(isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = gc();
    return f ? x = -x : 0;
}
template<typename type, typename ...T>
inline void rd(type &x, T &...y){rd(x), rd(y...);}
inline void gs(string &s){
    s.clear(); char c = gc();
    while(c == ' ' || c == '\n' || c == '\t' || c == '\r') c = gc();
    while(c != ' ' && c != '\n' && c != '\t' && c != '\r' && c != EOF) s += c, c = gc();
    return;
}
class Flush{public: ~Flush(){flush();}}___;
template<typename type>
inline void wr(type x){
    if(x < 0) pc('-'), x = -x;
    if(x > 9) wr(x / 10);
    pc(x % 10 + '0');
    return;
}
inline void wrs(const string& s){for(auto ch : s) pc(ch);}
namespace Solution{
    int n, m, b; ull sd;
    inline ull splitmix64(ull x){
        x += 0x9e3779b97f4a7c15;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }
    inline ull rnd(){
        sd ^= sd << 13, sd ^= sd >> 7;
        return sd ^= sd << 17;
    }
    const int inf = 1e9 + 114514;
    struct Matrix{
        int a[2][2];
        Matrix() : a({{0, inf}, {inf, 0}}) {}
        Matrix(int x, int y, int z, int w) : a({{z, y}, {x, w}}) {}
        Matrix& operator = (Matrix o){
            fo(i, 0, 1) fo(j, 0, 1) a[i][j] = o.a[i][j];
            return *this;
        }
        Matrix operator * (Matrix o) const{
            Matrix res;
            fo(i, 0, 1){
                fo(j, 0, 1){
                    int Min = inf;
                    fo(k, 0, 1){
                        Min = min(Min, a[i][k] + o.a[k][j]);
                    }
                    res.a[i][j] = Min;
                }
            }
            return res;
        }
        Matrix& operator *= (Matrix o){
            *this = *this * o;
            return *this;
        }
    }a[1 << 20], t[20][1 << 20]; // 对于每个位置,pre 和 suf 仅有一个生效,故合并为 t
    inline void genmat(Matrix& mat, ull x){
        fo(i, 0, 1) fo(j, 0, 1) mat.a[i][j] = x >> ((i << 1 | j) << 4) & 255;
        return;
    }
    inline void genqry(int& l, int& r, int n){
        if((rnd() & 1) && b){
            int c = rnd() % (n - b);
            l = rnd() % (n - c) + 1, r = l + c;
        }else{
            l = rnd() % n + 1, r = rnd() % n + 1;
            if (l > r) swap(l, r);
        }
        return;
    }
    inline int trans(Matrix x, Matrix y){
        int res = 0;
        fo(i, 0, 1) fo(j, 0, 1) res += x.a[i][j] ^ y.a[i][j];
        return res;
    }
    inline void build(int l, int r, int d){
        if(l == r) return;
        int mid = (l + r) >> 1;
        t[d][mid] = a[mid]; fd(i, mid - 1, l) t[d][i] = a[i] * t[d][i + 1];
        t[d][mid + 1] = a[mid + 1]; fo(i, mid + 2, r) t[d][i] = t[d][i - 1] * a[i];
        build(l, mid, d - 1), build(mid + 1, r, d - 1);
        return;
    }
    inline Matrix query(int l, int r){
        if(l == r) return a[l - 1];
        l--, r--;
        int d = 31 - __builtin_clz(l ^ r);
        return t[d][l] * t[d][r];
    }
    inline void Solve(){
        rd(n, m, sd, b), sd = splitmix64(sd);
        int x, y, z, w; rd(z, y, x, w); Matrix kv(x, y, z, w);
        fo(i, 1, n) genmat(a[i], rnd());
        // fo(_, 1, n){
        //     fo(i, 0, 1){
        //         fo(j, 0, 1){
        //             cerr << a[_].a[i][j] << ' ';
        //         }
        //         cerr << '\n';
        //     }
        //     cerr << '\n';
        // }
        int N = n, k = 0; n = 1; while(n < N) n <<= 1, k++;
        fo(i, 1, N) a[i - 1] = a[i]; fu(i, N, n) a[i] = Matrix();
        build(0, n - 1, k - 1);
        int ans = 0;
        while(m--){
            int l, r; genqry(l, r, N);
            // cerr << l << ' ' << r << '\n';
            Matrix res = query(l, r);
            ans ^= trans(res, kv);
        }
        wr(ans), pc('\n');
        return;
    }
}
bool ED;
signed main(){
    clock_t START = clock();
    // freopen("input.in", "r", stdin), freopen("output.out", "w", stdout);
    Solution::Solve();
    cerr << (double)(clock() - START) / CLOCKS_PER_SEC << " s" << '\n';
    cerr << 1.0 * abs(&ED - &ST) / 1024 / 1024 << " MB" << '\n';
    return 0;
}

:::

习题

参考文献