维护带修区间询问的利器——sqrt-tree

· · 个人记录

为方便起见,下文中提到的 size 为节点维护的区间长度。

引入

【模板】ST 表

题意:给定长度为 n 的序列,m 次询问,每次询问求 [l_i,r_i] 区间内的最大值,n \le 10^5m \le 2 \times 10 ^ 6

众所周知,区间 rmq 问题一般使用 ST 表解决。

但是,ST 表需要 O(n \log n) 的时间预处理,且并不支持所有满足结合律的计算。

那么现在对上面那道题目做一个加强:

由乃救爷爷

区间 rmq 问题,n,m \le 2 \times 10 ^ 7,时限 5s,空间 500 MB。

卡我 sqrt-tree 空间是吧

然后你会发现,ST 表预处理 O(n \log n) 的时间直接爆炸。

那么现在,我们想要一种预处理复杂度更低,而且询问复杂度仍然为 O(1) 的数据结构。

从分块出发

考虑对序列分块。

对于每个块,维护这个块的区间前缀答案 front_i、区间后缀答案 back_i,同时预处理出一个二维数组 bb_{i,j} 表示从第 i 个块到第 j 个块所构成的区间的答案。

显然这三个数组都是可以在 O(n) 的时间内求出的。

那么有了这三个数组的信息后,我们就可以 O(1) 回答所有跨块的询问,具体的,整块对答案的贡献在 b 数组中,而左右散块对答案的贡献则在 frontback 数组中。

询问时结合三段信息 O(1) 回答即可。

那不跨块的询问的?

废话啊,直接暴力啊,O(\sqrt{n})

直接暴力肯定是不可取的,因为上面那题数据随机,而数据随机=脚踩数据,所以没把你复杂度卡满,能过。

那我们怎么办呢?

考虑对这个块继续分块,这样就可以把一些在块内的询问就可以变成跨块询问,但是还有一些询问还在这个块分完的块中。

于是继续分块。

构建 sqrt-tree

不难想到这其实构成了一种树结构,每一个节点维护一条线段,叶子节点维护的线段的长度只能是 12

举个例子,如果我们要对一个长为 15 的序列 a 构建 sqrt-tree,那么这棵树的大体形状应该如下:

然后对每一个节点递推预处理出 frontbackb 数组,这是 O(size) 的。

因为一个维护区间长度为 k 节点的子节点有 \sqrt{k} 个,所以树高为 \log \log n,而每一层的区间总长度是 n,所以建树的时间复杂度是 O(n \log \log n) 的。


inline void make(int p , int l , int r , int now){
    int logdat = (floor_num[p] + 1) >> 1 , logcnt = floor_num[p] >> 1 , dat = 1 << logdat , cnt = (r - l + dat - 1) >> logdat;
    for(int i = 0;i < cnt;i++){
        sqrtnode ans;
        for(int j = i;j < cnt;j++){
            sqrtnode plusnum = back[p][l + (j << logdat)];
            if(i == j) ans = plusnum;
            else ans = op(ans , plusnum);
            b[p - 1][now + l + j + (i << logcnt)] = ans;
        }
    }
}
inline void makezero(){
    int logdat = (lgcnt + 1) >> 1;
    for(int i = 0;i < _index;i++){
        w[n + i] = back[0][i << logdat];
    }
    build(1 , n , n + _index , (1 << lgcnt) - n);
}

inline void build(int p , int l , int r , int now){
    if(p >= floor_num.size()) return ;
    int len = 1 << ((floor_num[p] + 1) >> 1);
    for(int i = l;i < r;i += len){
        int minn = min(i + len , r);
        makeblock(p , i , minn);
        build(p + 1 , i , minn , now);
    }
    if(!p){
        makezero();
    }
    else{
        make(p , l , r , now);
    }
}

sqrt-tree 的朴素询问

考虑最朴素的实现方式。

枚举树高,对于这一层的看是不是跨块询问,是跨块询问则 O(1) 回答并跳出,否则继续找。

时间复杂度 O(\log \log n)

sqrt-tree 的询问优化

直接把枚举树高变成二分树高即可。

时间复杂度 O(\log \log \log n)

sqrt-tree 的 O(1) 询问

先理解一个东西:我们找到了可以把询问变成跨块询问的树高就可以 O(1) 回答问题,方式见前。

首先,我们要对所有块的末尾添加一些元素,使其满足以下两个特征:

为了满足上述条件,我们需要在每一个块的后面添加一些不影响运算的值(比如在求区间最大值时添加极小值,在求区间和时添加 0)。

那为啥这样操作复杂度不会错呢?

因为每一个块的长度最多变为了原来长度的 2 倍,所以不影响复杂度。

现在我们来分析查询区间为 l,r 的询问如何 O(1) 确定树高。

我们知道每一个块的长度都是 2 的整数次方,设长度为 2 ^ k。所以每个块的左端点都是 2 ^ k 的倍数,右端点都是左端点加上 2 ^ k -1,所以在同一个块中的左右端点在二进制表示下仅有后 k 位不同。

因此我们只需要找到使得 lr 的后 k 位不同的最大的 k 即可。

然后显然就可以套一个异或操作上去,找最大的 k 就可以转化为找 l \oplus r 的最高位,直接暴力找即可。


inline sqrtnode query(int l , int r , int now , int k){
    if(l == r) return w[l];
    if(l == r - 1) return op(w[l] , w[r]);
    int p = on_floor[a[(l - k) ^ (r - k)]];
    int Log = (floor_num[p] + 1) >> 1;
    int cntlog = floor_num[p] >> 1;
    int ll = (((l - k) >> floor_num[p]) << floor_num[p]) + k;
    int blockl = ((l - ll) >> Log) + 1;
    int blockr = ((r - ll) >> Log) - 1;
    sqrtnode ans = back[p][l];
    if(blockl <= blockr){
        sqrtnode plus;
        if(!p){
            plus = query(n + blockl , n + blockr , (1 << lgcnt) - n , n);
        }
        else{
            plus = b[p - 1][now + ll + (blockl << cntlog) + blockr];
        }
        ans = op(ans , plus);
    }
    ans = op(ans , front[p][r]);
    return ans;
}

sqrt-tree 的单点修改

我们来考虑对于一次暴力的单点修改操作 a_p=x,sqrt-tree 里面哪些值会更新。

考虑一个长度为 s 的序列,容易发现 front 以及 back 只有 \sqrt{s} 个元素被更改了,修改的瓶颈在于根节点的 b 数组有 s 个元素被更新了。

这时,我们引出一个非常牛逼的东西——使用第二个 sqrt-tree 来代替根节点的 b 数组,我们称这个 sqrt-tree 为 id 树。注意,只有根节点的 b 数组用 sqrt-tree 来替代。

于是我们的单点修改是这样的:

  1. O(\sqrt{n}) 的复杂度更新 frontback 数组。
  2. 更新 id 树,虽然它的长度是 O(n) 的,但是我们只要修改一个元素,也是 O(\sqrt{n})
  3. 使用暴力更新的方法更新被修改的那个元素。

inline void repair(int p , int l , int r , int now , int x){
    if(p >= floor_num.size()) return ;
    int Log = (floor_num[p] + 1) >> 1;
    int date = 1 << Log;
    int blockid = (x - l) >> Log;
    int ll = l + (blockid << Log) , rr = min(ll + date , r);
    makeblock(p , ll , rr);
    if(!p){
        updatezero(blockid);
    }
    else{
        make(p , l , r , now);
    }
    repair(p + 1 , ll , rr , now , x);
}

inline void update(int pos , const sqrtnode &x){
    w[pos] = x;
    repair(0 , 0 , n , 0 , pos);
}

拓展:sqrt-tree 的区间修改

没错,sqrt-tree 还支持区间推平操作。

我们引入懒标记思想,在更新的时候在 sqrt-tree 上打懒标记。

然后根据懒标记的下传方式有两种实现方法:

第一种实现

只给第一层的节点打标记。

  1. 对于那些整块修改的节点,在这里打一个懒标记。
  2. 对于左右两边的被部分修改的散块,直接用 O(\sqrt{n} \log \log n) 的时间重构这两个块,如果这两个块之前有懒标记,就下传。

这时候,询问的方式也要改了:

  1. 如果询问被完全包含在一个块内,可以利用懒标记来算答案。
  2. 如果询问是跨块的,那么只要关心左散块和右散块的答案,因为中间的整块可以使用 id 树来计算答案(因为每次询问后 id 树重建了),复杂度 O(1)

所以询问复杂度仍是 O(1)

第二种实现

对于每一个节点打懒标记。

  1. 被整块包含在修改区间内的块,O(\sqrt{n}) 打懒标记。
  2. 被部分覆盖的散块,O(\sqrt{n}) 更新它们的 frontback 数组。
  3. 对于根节点结构不为 id 树式结构的子树,更新他们的 b 数组。
  4. 用上述方式(递归)更新两个没有被全覆盖的区间。

时间复杂度 O(\sqrt{n})

询问时要考虑这个节点祖先的懒标记,询问复杂度变为 O(\log \log n)

实现

区间推平操作懒得写了。

复杂度:建树 O(n \log \log n),询问 O(1),单点修改 O(\sqrt{n})

本代码为查询区间最大值。

#include <bits/stdc++.h>
using namespace std;
struct sqrtnode{
    int x;
    bool operator <(const sqrtnode &b) const{
        return x < b.x; 
    }
};
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}

sqrtnode op(const sqrtnode &a , const sqrtnode &b);
inline int ceil_log2(int n){
    int cnt = 0;
    while((1 << cnt) < n){
        cnt++;
    }
    return cnt;
}
class sqrt_tree{
    private:
        int n , lgcnt , _index;
        vector <sqrtnode> w;
        vector <int> a , floor_num , on_floor;
        vector < vector <sqrtnode> > front , back , b;
        inline void makeblock(int p , int l , int r){
            front[p][l] = w[l];
            for(int i = l + 1;i < r;i++){
                front[p][i] = op(front[p][i - 1] , w[i]);
            }
            back[p][r - 1] = w[r - 1];
            for(int i = r - 2;i >= l;i--){
                back[p][i] = op(back[p][i + 1] , w[i]);
            }
        }
        inline void make(int p , int l , int r , int now){
            int logdat = (floor_num[p] + 1) >> 1 , logcnt = floor_num[p] >> 1 , dat = 1 << logdat , cnt = (r - l + dat - 1) >> logdat;
            for(int i = 0;i < cnt;i++){
                sqrtnode ans;
                for(int j = i;j < cnt;j++){
                    sqrtnode plusnum = back[p][l + (j << logdat)];
                    if(i == j) ans = plusnum;
                    else ans = op(ans , plusnum);
                    b[p - 1][now + l + j + (i << logcnt)] = ans;
                }
            }
        }
        inline void makezero(){
            int logdat = (lgcnt + 1) >> 1;
            for(int i = 0;i < _index;i++){
                w[n + i] = back[0][i << logdat];
            }
            build(1 , n , n + _index , (1 << lgcnt) - n);
        }
        inline void updatezero(int id){
            int Log = (lgcnt + 1) >> 1;
            w[n + id] = back[0][id << Log];
            repair(1 , n , n + _index , (1 << lgcnt) - n , n + id);
        }
        inline void build(int p , int l , int r , int now){
            if(p >= floor_num.size()) return ;
            int len = 1 << ((floor_num[p] + 1) >> 1);
            for(int i = l;i < r;i += len){
                int minn = min(i + len , r);
                makeblock(p , i , minn);
                build(p + 1 , i , minn , now);
            }
            if(!p){
                makezero();
            }
            else{
                make(p , l , r , now);
            }
        }
        inline void repair(int p , int l , int r , int now , int x){
            if(p >= floor_num.size()) return ;
            int Log = (floor_num[p] + 1) >> 1;
            int date = 1 << Log;
            int blockid = (x - l) >> Log;
            int ll = l + (blockid << Log) , rr = min(ll + date , r);
            makeblock(p , ll , rr);
            if(!p){
                updatezero(blockid);
            }
            else{
                make(p , l , r , now);
            }
            repair(p + 1 , ll , rr , now , x);
        }
        inline sqrtnode query(int l , int r , int now , int k){
            if(l == r) return w[l];
            if(l == r - 1) return op(w[l] , w[r]);
            int p = on_floor[a[(l - k) ^ (r - k)]];
            int Log = (floor_num[p] + 1) >> 1;
            int cntlog = floor_num[p] >> 1;
            int ll = (((l - k) >> floor_num[p]) << floor_num[p]) + k;
            int blockl = ((l - ll) >> Log) + 1;
            int blockr = ((r - ll) >> Log) - 1;
            sqrtnode ans = back[p][l];
            if(blockl <= blockr){
                sqrtnode plus;
                if(!p){
                    plus = query(n + blockl , n + blockr , (1 << lgcnt) - n , n);
                }
                else{
                    plus = b[p - 1][now + ll + (blockl << cntlog) + blockr];
                }
                ans = op(ans , plus);
            }
            ans = op(ans , front[p][r]);
            return ans;
        }
    public:
        inline sqrtnode query(int l , int r){
            return query(l , r , 0 , 0);
        }
        inline void update(int pos , const sqrtnode &x){
            w[pos] = x;
            repair(0 , 0 , n , 0 , pos);
        }
        sqrt_tree(){

        }
        sqrt_tree(const vector <sqrtnode> &qwq) : n(qwq.size()) , lgcnt(ceil_log2(n)) , w(qwq) , a(1 << lgcnt) , on_floor(lgcnt + 1){
            a[0] = 0;
            for(int i = 1;i < a.size();i++){
                a[i] = a[i >> 1] + 1;
            }
            int tmplg = lgcnt;
            while(tmplg > 1){
                on_floor[tmplg] = floor_num.size();
                floor_num.push_back(tmplg);
                tmplg = (tmplg + 1) >> 1;
            }
            for(int i = lgcnt - 1;i >= 0;i--){
                on_floor[i] = max(on_floor[i] , on_floor[i + 1]);
            }
            int maxn = max(0 , (int)floor_num.size() - 1);
            int Log = (lgcnt + 1) >> 1;
            int noww = 1 << Log;
            _index = (n + noww -1) >> Log;
            w.resize(n + _index);
            front.assign(floor_num.size() , vector <sqrtnode> (n + _index));
            back.assign(floor_num.size() , vector <sqrtnode> (n + _index));
            b.assign(maxn , vector <sqrtnode> ((1 << lgcnt) + noww));
            build(0 , 0 , n , 0);
        }
};
sqrtnode op(const sqrtnode &a , const sqrtnode &b){
    return (a < b ? b : a);
}
signed main(){
//  freopen(".in" , "r" , stdin);
//  freopen(".out" , "w" , stdout);
    sqrtnode tmp;
    vector <sqrtnode> a;
    int n = read() , q = read();
    for(int i = 1;i <= n;i++) tmp.x = read() , a.push_back(tmp);
    sqrt_tree t(a);
    while(q--){
        int l = read() , r = read();
        cout << t.query(l - 1 , r - 1).x << '\n';
    }
    return 0;
}

当然,sqrt-tree 还有这样那样的不足 (比如空间),欢迎在评论区提出解决方案。

当然了,对于空间的优化,我们可以限制 sqrt-tree 的递归分块层数,虽然说慢了一点,但是复杂度还是十分优秀,最后就做到的空间可接受。

其实 sqrt-tree 的结构和线段树是比较类似的,这个应该就是和线段树的底层分块差不多的。

所以有些性质我们就可以搬到 sqrt-tree 上,这里就不做过多的展开了。

至于 sqrt-tree 的优势,就在于建树的 O(n \log \log n) 和 询问的 O(1) 的优秀复杂度了。

最后再说一句,这玩意好像和多叉线段树比较的像,所以有可能支持一些别的操作,也欢迎各位大佬在评论区提出。

参考文献:oi-wiki sqrt-tree