【数据结构】树状数组

· · 算法·理论

众所周知,要在维护前缀和的同时处理单点修改,我们可以使用线段树。只是,它的码量非常大,并且很难调试。而如果仅维护前缀和数组,又无法高效地处理修改。

这时,我们可以选择使用一种数据结构——树状数组。虽然它的时间复杂度整体较线段树并没有优化,但其码量相比一般的线段树可以说是惊人地小。

约定

本文中的代码仅作为示范,请根据题目实际情况选择合适的数据类型与数组大小

简介

树状数组是一种用于处理单点修改与区间查询的数据结构。它相比线段树适用范围更小,主要原因是

  1. 树状数组仅维护前缀和,不能处理无法通过前缀和查询区间的区间查询操作(如区间最大值,但前缀最大值是可以的)。
  2. 树状数组仅支持处理单点修改与区间查询。

也正是更小的适用范围,换来了相较线段树的优化空间。

::cute-table 40 < < < < < < <
21 < < < 19 < < <
9 < 12 < 11 < 8 <
1 8 3 9 7 4 2 6

这是一棵维护区间和的线段树,在查询一般的区间和时,其中的每个节点都是有可能访问到的。但在查询前缀和时并不是这样。例如,现在我们要查询区间 [1,5],[1,6],[1,7] 的区间和,它们访问到的节点如下所示。

::cute-table 40 < < < < < < <
21 < < < 19 < < <
9 < 12 < 11 < 8 <
1 8 3 9 7 4 2 6
::cute-table 40 < < < < < < <
21 < < < 19 < < <
9 < 12 < 11 < 8 <
1 8 3 9 7 4 2 6
::cute-table 40 < < < < < < <
21 < < < 19 < < <
9 < 12 < 11 < 8 <
1 8 3 9 7 4 2 6

可以发现,最底层的第 6 个节点处于十分尴尬的位置——在查询前缀和时,它根本不会被访问到。于是我们可以将所有这样冗余的节点都去掉。

::cute-table 40 < < < < < < <
21 < < < —— < < <
9 < —— < 11 < —— <
1 —— 3 —— 7 —— 2 ——

这就是树状数组的结构。它通常用一维数组存储,每个位置存储以该位置结尾的区间和

::cute-table 1 9 3 21 7 11 2 40

但是,存储好了,我们如何对其进行查询与更新呢?

单点加区间和

标准的树状数组就是用于处理单点加区间和问题的。

查询

树状数组的操作均依赖 \text{lowbit}。它指的是一个数的二进制表示中最低位的 1 对应的值。例如,\text{lowbit}(12_{10})=\text{lowbit}(1100_2)=100_2=8_{10}

在计算机中,通常使用如下的方式求解它。

int lowbit(int x)
{
    return x&-x;
}

具体求解方式涉及到补码与原码的关系,这里不展开讲解。

对于树状数组 b,查询区间 [1,\text{pos}] 的前缀和的代码如下。

int query(int pos)
{
    int ans=0;
    while(pos>=1) ans+=b[pos],pos-=lowbit(pos);
    return ans;
}

这看起来有点莫名其妙,树状数组和 \text{lowbit} 有什么关系呢?

我们要知道,树状数组本质上就是对前缀和的二进制分解。例如,要查询区间 [1,5] 的前缀和,我们先把 5 转换成二进制,即 101_2。依照这个二进制的每个 1 对应的值分别作为区间长度,可以把区间 [1,5] 拆成 [1,4],[5,5]

::cute-table 40 < < < < < < <
21 < < < —— < < <
9 < —— < 11 < —— <
1 —— 3 —— 7 —— 2 ——

可以发现,对于当前位置 \text{now}\text{lowbit}(\text{now}) 就是树状数组当前节点对应的区间长度,也就可以知道 \text{now}-\text{lowbit}(\text{now}) 就是下一个需要访问的区间的结尾位置。还记得树状数组中的元素是怎么存放的吗?

它通常用一维数组存储,每个位置存储以该位置结尾的区间和

没错,这意味着我们又可以得到下一个区间的区间和 b_{\text{now}-\text{lowbit}(\text{now})}。以此类推,就能实现前缀和的查询。

请注意树状数组必须从下标 1 开始存储,因为 \text{lowbit}(0)=0,会导致死循环。

更新

首先看看 b 和原数组 a 有什么关系。有了上面推导的过程,很容易知道

b_i=\sum^i_{j=i-\text{lowbit}(i)+1} a_j

现在假设 a_\text{pos} 增加了 k。我们需要修改所有 b_x,使得 x-\text{lowbit}(x)+1\le \text{pos}\le x。可以使用如下代码。

void update(int pos,int k)
{
    for(int i=pos;i<=n;i++)
        if(i-lowbit(i)+1<=pos) b[i]+=k;
    return;
}

这样做的时间复杂度是 \mathcal{O}(n)。如何优化呢?可以发现,在遍历中,许多位置并不合法,却仍然要尝试一遍。下一个位置的索引是否与当前位置的索引有关系呢?不妨以 \text{pos}=13 的情况为例,列出合法的位置为 13,14,16,32,\dots,转换成二进制是 1101_2,1110_2,10000_2,100000_2,\dots。看出规律了吗?对于当前位置 \text{now},下一个位置就是 \text{now}+\text{lowbit}(\text{now})

void update(int pos,int k)
{
    while(pos<=n) b[pos]+=k,pos+=lowbit(pos);
    return;
}

也可以通过数学证明的方式严格证明下一个位置是 \text{now}+\text{lowbit}(\text{now}),这里不展开讲解。

在建立树状数组时,我们只需要进行 n 次单点更新即可。这样,我们就写出了一个树状数组。

#include<bits/stdc++.h>
using namespace std;
int n,m,a[500005],b[500005];
int lowbit(int x)
{
    return x&-x;
}
void update(int pos,int k)
{
    while(pos<=n) b[pos]+=k,pos+=lowbit(pos);
    return;
}
int query(int pos)
{
    int ans=0;
    while(pos>=1) ans+=b[pos],pos-=lowbit(pos);
    return ans;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i],update(i,a[i]);
    while(m--)
    {
        int op,x,y;
        cin>>op>>x>>y;
        if(op==1) update(x,y);
        else cout<<query(y)-query(x-1)<<endl;
    }
    return 0;
}

这里可以看出来,树状数组在代码实现上好像和树没有半毛钱关系。它之所以被称为“树状数组”,主要还是因为在逻辑上它是一颗树。

区间加单点值

树状数组也可以处理区间加单点值的问题,不过这需要使用差分。 ::::info[知识点-差分]{open} 差分是前缀和的逆运算,不过它们都用来处理区间问题。具体地,差分常用来处理多次区间修改,并在修改操作完成后通过前缀和还原修改后的序列。单纯的差分一般用来解决离线区间修改问题。

对于序列 a,它的差分数组 d 与其关系如下。

d_i=a_i-a_{i-1},a_0=0

说白了,差分数组存储的就是“当前位置的元素相较于前一个位置的元素的变化值”。这样,如果我们要给区间 [l,r] 中的元素都加上 k,我们只需要

d_l\leftarrow d_l+k,d_{r+1}\leftarrow d_{r+1}-k

即可。

与前缀和类似,同样有二维乃至多维的差分。具体见 OI-Wiki,下文也会用到二维前缀和与二维差分。 :::: 显然,要实现区间加,我们可以维护一个原数组的差分数组。但由于查询操作和修改操作的顺序并不一定,不能仅使用差分,因为查询操作需要对差分数组求前缀和。求前缀和?这恰好是树状数组的功能。并且树状数组刚好还可以实现单点更新,与差分数组的修改操作不谋而合。于是方案呼之欲出——我们使用树状数组维护差分数组即可。

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i],update(i,a[i]-a[i-1]);
    while(m--)
    {
        int op,x,y,k;
        cin>>op>>x;
        if(op==1)
        {
            cin>>y>>k;
            update(x,k),update(y+1,-k);
        }
        else cout<<query(x)<<endl;
    }
    return 0;
}

区间加区间和

P3372 的标签虽说只有线段树,但我们也是可以用树状数组通过这道题的。利用树状数组解决区间加区间和问题同样需要通过使用差分将问题转化为维护与查询前缀和。

由差分数组的定义,我们可以知道

a_i=\sum^i_{j=1} d_j

把它代入到 \sum^r_{i=1}a_i 中,有

\sum^r_{i=1}a_i=\sum^r_{i=1}\sum^i_{j=1}d_j=\sum^r_{i=1}d_i\times(r-i+1)=(r+1)\sum^r_{i=1}d_i-\sum^r_{i=1}i\times d_i

于是我们只要用两个树状数组分别维护 d_i,i\times d_i 的前缀和就行。

#include<bits/stdc++.h>
using namespace std;
int n,m;
long long a[100005],b[100005],c[100005];
int lowbit(int x)
{
    return x&-x;
}
void update1(int pos,long long k)
{
    while(pos<=n) b[pos]+=k,pos+=lowbit(pos);
    return;
}
void update2(int pos,long long k)
{
    while(pos<=n) c[pos]+=k,pos+=lowbit(pos);
    return;
}
long long query1(int pos)
{
    long long ans=0;
    while(pos>=1) ans+=b[pos],pos-=lowbit(pos);
    return ans;
}
long long query2(int pos)
{
    long long ans=0;
    while(pos>=1) ans+=c[pos],pos-=lowbit(pos);
    return ans;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i],update1(i,a[i]-a[i-1]),update2(i,(long long)i*(a[i]-a[i-1]));
    while(m--)
    {
        int op,x,y;
        cin>>op>>x>>y;
        if(op==1)
        {
            long long k;
            cin>>k;
            update1(x,k),update1(y+1,-k),update2(x,(long long)x*k),update2(y+1,-(long long)(y+1)*k);
        }
        else cout<<(y+1)*query1(y)-query2(y)-x*query1(x-1)+query2(x-1)<<endl;
    }
    return 0;
}

二维树状数组

二维树状数组,也被称作树状数组套树状数组。它的作用很简单,就是维护与查询二维前缀和。P4514 是它的模板题。

由于单点加实际只是矩阵加的一个子问题,我们直接考虑矩阵加矩阵和。首先来看看二维前缀和是怎么实现的。类似于一维,二维前缀和 S_{i,j} 定义为

S_{i,j}=\sum_{k=1}^i\sum_{t=1}^j a_{k,t}

在建立前缀和数组时,遵循

S_{i,j}=S_{i,j-1}+S_{i-1,j}+a_{i,j}-S_{i-1,j-1}

的递推关系。查询时,左上角为 (a,b),右下角为 (c,d) 的矩阵和为

S_{c,d}-S_{a-1,d}-S_{c,b-1}+S_{a-1,b-1}

具体见 OI-Wiki。

与一维树状数组类似,我们让 b_{i,j} 表示右下角为 (i,j),高 \text{lowbit}(i),宽 \text{lowbit}(j) 的矩阵和。对应的代码也就不难写出。

void update(int x,int y,int k)
{
    for(int i=x;i<=n;i+=lowbit(i))
        for(int j=y;j<=n;j+=lowbit(j)) b[i][j]+=k;
    return;
}
int query(int x,int y)
{
    int ans=0;
    for(int i=x;i>=1;i-=lowbit(i))
        for(int j=y;j>=1;j-=lowbit(j)) ans+=b[i][j];
    return ans;
}

接下来将原问题利用差分转化。首先要知道二维差分的定义,由于在差分数组中再做一次二维前缀和得到的就是原数组,显然有

a_{i,j}=a_{i,j-1}+a_{i-1,j}+d_{i,j}-a_{i-1,j-1} d_{i,j}=a_{i,j}-a_{i,j-1}-a_{i-1,j}+a_{i-1,j-1}

那么,对左上角为 (a,b),右下角为 (c,d) 的矩阵中的所有元素加上 k 的操作可以转化为

d_{a,d+1}\leftarrow d_{a,d+1}-k\\ d_{c+1,b}\leftarrow d_{c+1,b}-k\\ d_{c+1,d+1}\leftarrow d_{c+1,d+1}+k

这便是矩阵加的处理。

对于矩阵和的查询,我们和之前一样写出 (x,y) 的二维前缀和为

&\sum^x_{i=1}\sum^y_{j=1}\sum^i_{k=1}\sum^j_{t=1}d_{k,t}\\ &=\sum^x_{i=1}\sum^y_{j=1}d_{i,j}\times(x−i+1)\times(y−j+1)\\ &=\sum^x_{i=1}\sum^y_{j=1}d_{i,j}\times(xy+x+y+1)-d_{i,j}\times i\times(y+1)-d_{i,j}\times j\times(x+1)+d_{i,j}\times i\times j \end{aligned}

接下来就简单了。

#include<bits/stdc++.h>
using namespace std;
int n,m,a[2049],b[2049][2049],c[2049][2049],d[2049][2049],e[2049][2049];
int lowbit(int x)
{
    return x&-x;
}
void update(int x,int y,int k)
{
    for(int i=x;i<=n;i+=lowbit(i))
        for(int j=y;j<=n;j+=lowbit(j)) b[i][j]+=k,c[i][j]+=k*x,d[i][j]+=k*y,e[i][j]+=k*x*y;
    return;
}
int query(int x,int y)
{
    int ans=0;
    for(int i=x;i>=1;i-=lowbit(i))
        for(int j=y;j>=1;j-=lowbit(j)) ans+=(x+1)*(y+1)*b[i][j]-(y+1)*c[i][j]-(x+1)*d[i][j]+e[i][j];
    return ans;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    char op;
    cin>>op>>n>>m;
    while(cin>>op)
    {
        int x,y,p,q,k;
        cin>>x>>y>>p>>q;
        if(op=='L') cin>>k,update(x,y,k),update(x,q+1,-k),update(p+1,y,-k),update(p+1,q+1,k);
        else cout<<query(p,q)-query(p,y-1)-query(x-1,q)+query(x-1,y-1)<<endl;
    }
    return 0;
}

它的时间复杂度此时为 \mathcal{O}(\log^2 n),空间复杂度为 \mathcal{O}(n^2)

\mathcal{O}(n) 建树

到现在,我们建立树状数组的方式是进行 n 次单点更新,时间复杂度是 \mathcal{O}(n\log n)。但其实我们也可以用 \mathcal{O}(n) 的时间复杂度建立树状数组。

之前我们一直忽略了树状数组的“树”结构。从上面的讲解中,我们可以知道,节点 i 的父节点就是 i+\text{lowbit}(i)。那么,我们不妨直接用子节点的信息维护父节点。

void build()
{
    for(int i=1;i<=n;i++)
    {
        b[i]+=a[i];
        if(i+lowbit(i)<=n) b[i+lowbit(i)]+=b[i];
    }
    return;
}

这样,所有子节点都刚好向其父节点进行一次贡献。

另一种方法,我们知道 b_i 表示的区间是 [i-\text{lowbit}(i)+1,i],所以我们可以维护一个原数组的前缀和数组,然后用前缀和数组直接计算出各个节点的值。

int pre[500005];
void build()
{
    for(int i=1;i<=n;i++) pre[i]=pre[i-1]+a[i],b[i]=pre[i]-pre[i-lowbit(i)];
    return;
}