让我们再疯狂一点吧——线段树的高维实现

· · 算法·理论

故事是这样的 热狗贩子有一天想到一个奇怪的问题: 树状数组的高维实现只需要多套几个for循环就行了,那么同为动态区间管理的数据结构的线段树有没有什么高维的实现方法呢? 停!我知道你们肯定想要讲树套树!这个我们今天不谈,我们今天来讲讲另外一种 一层树就可以管理高维数组的线段树

\color{Red}\colorbox{White}{本文将会以三维数组为样例进行讨论,笔者将三维} \color{Red}\colorbox{White}{数组抽象为立方体,这样看起来更直观}

从一维线段树的管理模式谈起

一维线段树大家都熟悉,对于每一个管辖区间不为单个的节点,将它的区间切割为左区间和右区间分别交给它的两个儿子进行管理,它则依托 push_up 函数进行管理,我们以一个只包含求和功能,管理长度为五的数组的线段树为例(图中数字都是数值):

我们可以这样形象化它的 build 函数的过程:

简而言之,就是将自身的区间不断进行二分,直至分割为一个个单独的节点。

那么对于三维数组,情况又是怎么样呢?

假设我们有一个线段树上的节点,它管理着一个三维的数组的一个区间,我们用一个立方体表示:

先看 0 \sim n 的那条轴。根据线段树的常规操作应当是将 [0,n] 这个区间给一分两半,就像这样:

那么对于 0 \sim m , 0 \sim k ,这两条轴,则是同样的:

现在根据这三个被分割的轴,将这个管辖区域进行分割,就像这样:

你发现了什么?

我们成功分割出了八个该节点的儿子们的管辖区域!

由此,我们可以概括出:

对于 n 维的线段树,我们可以将它的每一维都分为两半,最后最多分割出 2^n 个孩子,它将分割出的 2^n 个区间交给它的 2^n 个孩子进行管理,自己再利用 push_up 等函数进行管理。

这就是高维线段树的基本运行原理。

探讨每一个节点的表示方法

在一维线段树中,我们用 lr 数组来存储每一个节点的管辖范围。

那么如果是二维线段树呢?很简单,加一个 ud 数组来管理。

但是如果是三维,四维,甚至五维呢?

你的耐心不够用了吧!

收到 MCfill 指令的启发,我们可以只存储一个节点管辖区间的一组对角的坐标。

那么用哪一组对角呢?

很简单,我们可以存储所有维坐标都最小的和所有维坐标都最大的一对点,比如在上图,就是这两个标绿的点:

我们再将 lr 两个数组进行改造,用来存储这两个点的坐标。这两个数组都设为二维,第一维表示节点编号,第二维存储坐标。

比如说,要想看三维空间中点 (x,y,z) 是否处于 a 号节点的管辖范围之类,只需要知道是否

l_{a,1} \leqslant x \leqslant r_{a,1} l_{a,2} \leqslant y \leqslant r_{a,2} l_{a,3} \leqslant z \leqslant r_{a,3}

这三个条件是否都符合。

对于更高维度的线段树,都是类似的。

它的孩子们......

是,这样划分孩子们是很不错,但是,如果简单地用向量或数组来存储孩子,如何确定每一个孩子管辖的区间的相对位置?

你这么想:每一维只被切割为两端,一段靠近原点,一段远离原点,这可以映射到二进制数的每一位上,只有两个状态:01

这是什么?没错这是 状态压缩

我们这样定义存储孩子的数组:

int children[节点数][2^维度数];

这样,我们就可以在确定孩子的同时确定位置!

比如,在本文的三维数组中,最为靠近原点的点状态压缩的下标为 000 的十进制,即 0 ;相对的,最远离原点的 111 ,十进制则为 7 ;在比如说 101 ,即十进制的 5 表示它的第二维靠近原点,第一维和第三维远离原点,以此类推。

这样子遍历儿子们也会变得特别方便, 我们只需要利用 for 循环从 0 遍历到 2^n-1 即可遍历一个节点的儿子们

基本代码实现就像这样:

for(int i=0;i<8;i++){
    int ch=children[x][i];//当前遍历到的孩子
    if(符合深搜条件){
        dfs(ch);
    }
}

细节,全是细节

说了这么多牛气哄哄的东西,但是还有一个很重要的问题没有解决:

如果在分割过程之中,一部分维度率先被分割得只剩一个单位长度了,如何分类讨论?

不要以为这是 if-else 判断分支可以简单解决的!每一维都有到和没到两种情况,如果是 n 维的话,这个分支情况可以达到恐怖的 2^n种情况,而且相信笔者,你没有那么能吃苦,打不到4个判断分支你就得崩溃!

那么怎么办呢?

我们再来看看这个分割区间的过程:

先切一刀。

假设左下角的这条轴上我们已经切到了1个单位长度,不能再进行切割,那怎么办?很简单! 不切了呗!

那么是不是只要一维一维地侦测,不能切就不切呢?

恭喜你,你答对了!

考虑创建一个数据结构 node 用来存储一个区间的范围,方法同上:

struct node{
    int left[维度+1];
    int right[维度+1];
};

接着创建两个 multiset 来捣鼓

multiset<node> st;
multiset<node> st2;

然后把这个大节点自己塞进 st

node llre;
for(int i=1;i<=wd;i++){
    llre.left[i]=left[i];
    llre.right[i]=right[i];
}
st.insert(llre);

接下来,第一层 for 循环是每一维度,如果探测到已经被切割成为1个单位长度,就跳过,如果不是, st 中的所有的区间都在这一维分割,分割出来储存进 st2 中:

for(multiset<node>::iterator j=st.begin();j!=st.end();j++){
    node tmp=*j,lre;
    for(int l=1;l<=wd;l++){
        lre.left[l]=tmp.left[l];
        lre.right[l]=tmp.right[l];
    }
    lre.left[i]=left[i];
    lre.right[i]=mid[i];
    st2.insert(lre);
    lre.left[i]=mid[i]+1;
    lre.right[i]=right[i];
    st2.insert(lre);
}

然后把 st 清空, st2 中的所有区间都塞到 st 中去。

st.clear();
for(multiset<node>::iterator j=st2.begin();j!=st2.end();j++){
    st.insert(*j);
    }
st2.clear();

如此往复循环,就完成了对当前区间的切割。至于一个个创建孩子,这就更简单了,利用状压确定当前的区间应被存储于哪一个下标之中,就像这样:

for(multiset<node>::iterator i=st.begin();i!=st.end();i++){
    node tmp=*i;
    int ll[4],rr[4],ch=0;
    for(int l=1;l<=维度数;l++){
        ll[l]=tmp.left[l];
        rr[l]=tmp.right[l];
    }//这个儿子的区间记录下来,等会塞进去
    for(int j=1;j<=wd;j++){
        if(ll[j]==left[j]){
            ch=(ch|(0<<(j-1)));
        }else{
            ch=(ch|(1<<(j-1)));
        }
    }//确定下标
        children[x][ch]=cnt++;
        build(children[x][ch],ll,rr);//build函数后两个参数为数组,因为是多维的
}

那么没有切的那一维呢,就默认全都状压的这一位为 0

其他线段树常用函数在高维时的表现

figure()

这个函数一般不会出现在一维线段树中,它的作用是统计当前节点管辖的元素个数,配合懒标记食用。

代码实现:

int figure(int x){
    int lre=1;
    for(int i=1;i<=维度数;i++){
        lre*=(r[x][i]-l[x][i]+1);
    }
    return lre;
}

push_up()

这很简单,以带懒标记的求和为例子,一个个遍历就行了,push_down 也一样。

代码实现:

void push_up(int x){
    sum[x]=0;
    for(int i=0;i<2^维度数;i++){
        if(children[x][i]==0){
        //这个孩子不存在
            continue;
        }
        sum[x]+=sum[children[x][i]]+delta[children[x][i]]*figure(children[x][i]);、
    }
} 

change()

保持和一维线段树一致的思路,如果被更改范围完全包含,就更改自身的懒标记,否则下潜。

代码实现:

void plu(int x,int left[],int right[],int val){
    if(被包含){
        delta[x]+=val;
        return;
    }

    for(int i=0;i<8;i++){
        int ch=children[x][i];
        if(这个孩子管辖范围有被包含的区域){
            plu(ch,left,right,val);
        }
    }

    push_up(x);
}

query()

作者个人认为这是高维线段树中最简单的部分,只需将下面的代码打上注释的两行进行更改,你就可以用它来查询任意你想要的!

int query_sum(int x,int left[],int right[]){
    if(inside(x,left,right)){
        return sum[x]+delta[x]*figure(x);//1 
    }

    push_down(x);
    int lre=0;

    for(int i=0;i<8;i++){
        int ch=children[x][i];
        if(okay(ch,left,right)){
            lre+=query_sum(ch,left,right);//2
        }
    }
    push_up(x);
    return lre;
}

最终代码实现

下面这个代码是以三维数组为例,但更改参数以后,它可以存储任意维度的线段树,前提是数组开得下!

#include<iostream>
#include<set> 

using namespace std;

//这里以一个10*10*10,支持区间求和,加值,求最大的线段树为例
//值域在0~1e9之间 

struct node{
    int left[4];
    int right[4];
};

//参照一维线段树,若数组有x个元素,理论上应开足4X的空间 
int cnt=1;
int children[4005][9];
//这里管理孩子的数组第二维参考了状态压缩 
int l[4005][4],r[4005][4];
/*
这里的l和r不同于一般意义上的左右,我们用一组对角的坐标表示这个节点
在高维空间内管辖的范围,l是最靠近原点的点的坐标,r是最远离原点的点的坐标 
*/
int sum[4005],delta[4005],ma[4005];
//虽然维度变高,但是线段树依然用线性方式存储,因此这些数组保持一维 

int wd=3;
int n=10; 
int arr[15][15][15];

bool operator <(node a,node b){
    return a.left[1]<b.left[1];
}//这个东西没有实际意义,只是为了让下文的set可以正常运行

bool operator ==(node a,node b){
    return false;
}//这个东西没有实际意义,只是为了让下文的set可以正常运行

bool one_point(int left[],int right[]){
    for(int i=1;i<=wd;i++){
        if(left[i]!=right[i]){
            return false;
        } 
    }
    return true;
}

int figure(int x){
    int lre=1;
    for(int i=1;i<=wd;i++){
        lre*=(r[x][i]-l[x][i]+1);
    }
    return lre;
}//管辖节点个数 

void push_up(int x){
    ma[x]=-1;
    sum[x]=0;
    for(int i=0;i<8;i++){
        if(children[x][i]==0){
            continue;
        }
        sum[x]+=sum[children[x][i]]+delta[children[x][i]]*figure(children[x][i]);
        ma[x]=max(ma[x],ma[children[x][i]]+delta[children[x][i]]);
    }
} 

void push_down(int x){
    for(int i=0;i<8;i++){
        delta[children[x][i]]+=delta[x];
    }
    delta[x]=0;
    push_up(x);
}

void build(int x,int left[],int right[]){

    for(int i=1;i<=wd;i++){
        l[x][i]=left[i];
        r[x][i]=right[i];
    }

    if(one_point(left,right)){
        //侦测到管辖范围已经成为一个点
        sum[x]=ma[x]=arr[left[1]][left[2]][left[3]];
        delta[x]=0;
        return; 
    }

    int mid[4];
    int nleft[4],nright[4];
    for(int i=1;i<=wd;i++){
        mid[i]=left[i]+right[i]>>1;
    }

    /*
    接下来是高维线段树中最为困难的部分, 产生它的孩子
    基本思路:受到细胞分裂启发,我们一维一维地将先前的
    新产生数组left和right分裂,如果这一维left[i]==right[i], 
    则不分裂,依托set和上文中的node实现 
    */

    multiset<node> st;
    multiset<node> st2;
    //必须使用multiset,否则分裂出来的孩子会莫名其妙地失踪 
    node llre;
    for(int i=1;i<=wd;i++){
        llre.left[i]=left[i];
        llre.right[i]=right[i];
    }
    st.insert(llre);
    for(int i=1;i<=wd;i++){
        if(left[i]==right[i]){
            continue;
        }
        for(multiset<node>::iterator j=st.begin();j!=st.end();j++){
            node tmp=*j,lre;
            for(int l=1;l<=wd;l++){
                lre.left[l]=tmp.left[l];
                lre.right[l]=tmp.right[l];
            }
            lre.left[i]=left[i];
            lre.right[i]=mid[i];
            st2.insert(lre);
            lre.left[i]=mid[i]+1;
            lre.right[i]=right[i];
            st2.insert(lre);
            //分裂 
        }
        st.clear();
        for(multiset<node>::iterator j=st2.begin();j!=st2.end();j++){
            st.insert(*j);
            //show_node(*j);
        }
        st2.clear();
    }
    //循环过后,此时该节点的孩子全部存储在st中 

    for(multiset<node>::iterator i=st.begin();i!=st.end();i++){
        node tmp=*i;
        int ll[4],rr[4],ind[4],ch=0;
        for(int l=1;l<=wd;l++){
            ll[l]=tmp.left[l];
            rr[l]=tmp.right[l];
        }
        for(int j=1;j<=wd;j++){
            if(ll[j]==left[j]){
                ch=(ch|(0<<(j-1)));
            }else{
                ch=(ch|(1<<(j-1)));
            }
        }
        children[x][ch]=cnt++;
        build(children[x][ch],ll,rr);
    }
    push_up(x);
}

bool inside(int x,int left[],int right[]){
    for(int i=1;i<=wd;i++){
        if(!(left[i]<=l[x][i]&&r[x][i]<=right[i])){
            return false;
        }
    }
    return true;
}

bool okay(int x,int left[],int right[]){
    for(int i=1;i<=wd;i++){
        if(!(l[x][i]<=right[i]&&left[i]<=r[x][i])){
            return false;
        }
    }
    return true;
}

void plu(int x,int left[],int right[],int val){
    if(inside(x,left,right)){
        delta[x]+=val;
        return;
    }

    for(int i=0;i<8;i++){
        int ch=children[x][i];
        if(okay(ch,left,right)){
            plu(ch,left,right,val);
        }
    }

    push_up(x);
}

int query_sum(int x,int left[],int right[]){
    //实际上不同的查询中只有该函数中标注释的1,2两行要做改动 
    if(inside(x,left,right)){
        return sum[x]+delta[x]*figure(x);//1 
    }

    push_down(x);
    int lre=0;

    for(int i=0;i<8;i++){
        int ch=children[x][i];
        if(okay(ch,left,right)){
            lre+=query_sum(ch,left,right);//2
        }
    }
    push_up(x);
    return lre;
}

int query_max(int x,int left[],int right[]){
    //实际上不同的查询中只有该函数中标注释的1,2两行要做改动 
    if(inside(x,left,right)){
        return ma[x]+delta[x];//1
    }
    push_down(x);
    int lre=0;

    for(int i=0;i<8;i++){
        int ch=children[x][i];
        if(okay(ch,left,right)){
            lre=max(lre,query_max(ch,left,right));//2
        }
    }
    push_up(x);
    return lre;
}

void input(int a,int b,int c){
    if(a==0){
        for(int i=1;i<=n;i++){
            input(i,0,0);
        }
    }else if(b==0){
        for(int i=1;i<=n;i++){
            input(a,i,0);
        }
    }else if(c==0){
        for(int i=1;i<=n;i++){
            input(a,b,i);
        }
    }else{
        cin>>arr[a][b][c];
    }
}

signed main(){

    cin>>n;
    //输入长度,存储的表格就是 n*n*n的 
    input(0,0,0);
    //由于维度不确定,这里用递归替代for循环来输入

    int ll[4]={0,1,1,1},rr[4]={0,n,n,n};
    build(0,ll,rr);
    /*
    查询的使用方法:
    A x1 y1 z1 x2 y2 z2 val 表示将这个区间统一加上val
    B x1 y1 z1 x2 y2 z2 表示查询这个区间的和
    C x1 y1 z1 x2 y2 z2 表示查询这个区间的最大值 

    */
    int q;
    cin>>q;
    while(q--){
        char str;
        int l[4],r[4],val;
        cin>>str;
        for(int i=1;i<=wd;i++){
            cin>>l[i];
        }
        for(int i=1;i<=wd;i++){
            cin>>r[i];
        }
        if(str=='A'){
            cin>>val;
            plu(0,l,r,val);
        }else if(str=='B'){
            cout<<query_sum(0,l,r)<<endl;
        }else if(str=='C'){
            cout<<query_max(0,l,r)<<endl;
        }
    }

    return 0; 
} 

/*
simple input:
2
1 2
3 4

5 6
7 8

5
A 1 1 1 2 2 2 1
B 1 1 1 2 2 2
B 2 1 2 2 2 2
A 1 1 1 2 2 1 3
C 2 1 1 2 2 2

simple output:
44
16
11
*/

这个代码的长度带上调试代码(在上文已被删除)已经达到了恐怖的 361 行,这也从反面证明了为什么一般比赛不出不固定维度的数组或高维的数组,这也让作者认识到了:

下次二维问题还是写树状数组吧......