浅谈树状数组

· · 算法·理论

本文章同步发表在博客园。

树状数组是一种好吃的东西,建议大家学了以后多吃吃喵。(什

算法详解

定义

树状数组是一种支持单点修改可差分区间查询,并且码量低常数还小的非常赞的数据结构。

原理

首先我们都知道,任何一个数都可以表示成至多 \log2 的次幂的和,比如 13 = 2^3 + 2^2 + 2^1

这种思想就是树状数组的核心。

树状数组中,利用 \text{lowbit} 的性质,使得任意一个前缀和可以被 \log 个长度为 2 的次幂的区间和表示出来,并且任意一个位置都被最多 \log 个区间包括,那我们就可以实现 O(\log n) 的单点修改或区间查询了。

::::info[\text{lowbit} 是什么?] 它表示一个数在二进制表示下的最低位的值,比如 6\text{lowbit}28\text{lowbit}8。在代码实现中,我们可以用 x&-x 快速计算数字 x\text{lowbit} 值。 ::::

树状数组中具体是如何处理的呢?它定义了 c_i 表示以 i 为右端点且长度为 \text{lowbit}(i) 的区间和。那么比如长度为 7 的前缀和 sum_7 = c_7 + c_6 + c_4(分别为区间 [7,7][5,6][1,4])。

定义懂了,那么要单点修改或者查询前缀和分别该怎么做呢?

首先说单点修改。在修改第 x 个位置的时候,因为 c_i 存储的是右端点为 i 的区间和,所以修改时只需要找所有右端点 \ge x 且包含 x 的区间。由于区间长度为 \text{lowbit} 值,因此不难发现,所有 x < i < x+\text{lowbit}(x)i 对应的 c_i 的区间一定都无法包含 x,因为它的 \text{lowbit} 值不够长。除了 c_x,接下来一个包含 x 的就是 c_{x + \text{lowbit(x)}} 了,因为 x + \text{lowbit}(x)\text{lowbit} 至少是 \text{lowbit}(x) 的两倍,那么肯定是能够够得到 x 这一边的。以此类推,就这样不断找,由于每次 \text{lowbit} 值都在成至少两倍的速度增长,因此时间复杂度是 O(\log n) 的。

再来说查询前缀和。假设你要查询的是 [1,x] 这一段区间的前缀和,那么首先肯定要加上 c_x,这样你就已经统计了区间 [x-\text{lowbit}(x) + 1 , x] 的和,接下来就要从右端点 x-\text{lowbit} 找起了。于是就不断减 \text{lowbit} 就可以了,时间复杂度也是 O(\log n) 的。

实现

void upd(int x,int k){
    while(x<=n)
        c[x]+=k,x+=x&-x;
    return;
}//单点修改
int ask(int x){
    int sum=0;
    while(x)sum+=c[x],x-=x&-x;
    return sum;
}//前缀和查询
int query(int l,int r){
    return ask(r)-ask(l-1);
}//区间查询

运用

求逆序对数量

使用权值树状数组,维护每个数出现的次数,就像桶一样。顺序遍历 i1n,在统计 a_i 对逆序对的贡献时,查询目前树状数组中存储的大于 a_i 的数有多少个即可。处理完贡献后 upd(a[i],1);a_i 也存进这个用树状数组维护的桶里即可。

以下代码中假设 a 是一个长度为 n 的排列。其实核心代码就两行。

for(int i=1;i<=n;i++)
    Ans+=ask(n)-ask(a[i]),upd(a[i],1);

求区间最值

求区间最值其实一般不会用树状数组来着,可以考虑用线段树,不带修可以用 ST 表。

但是硬要用树状数组还是可以用的,只是时间复杂度有两只 \log 跑不快。

单点修改时,枚举到每个 x 都需要完整遍历一遍所有小于 \text{lowbit}(x)\text{lowbit} 以确保更新到所有包含 x 的区间;区间查询时,因为取最值无法差分处理,因此只能从 r 慢慢倒退回 l,能取完整区间就一次取完,不能取就只取单个值,慢慢往回推直到与 l 相遇即可。

以下代码中以求 \max 举例。

void upd(int x,LL k) {
    a[x]=k;
    while(x<=n){
        c[x]=a[x];
        for(int i=1;i<(x&-x);i<<=1)
            c[x]=max(c[x],c[x-i]),
        x+=(x&-x);
    }return;
}//单点修改
LL ask(int l,int r){
    LL res=a[r];
    while(l<=r){
        if(r-(r&-r)+1<l)res=max(res,a[r--]);
        else res=max(res,c[r]),r-=(r&-r);
    }return res;
}//区间查询

区间修改单点查询

你以为这个不能用树状数组维护?只能用线段树?当然不是啦!

你可以用树状数组维护原数列的差分数组,区间修改时调用 upd(l,k);upd(r+1,-k);,查询的时候 ask(x) 就可以直接得到原数列中 x 的值啦。比线段树好写多了呢!

区间修改区间查询

哇这个真的很神。那以后估计是没人用线段树搞这个了,它被抛弃咯(

依然是维护原数列的差分数组,求原序列的区间和等价于求原序列的前缀和。原序列是差分数组的前缀和,所以我们需要求的是差分数组的前缀和的前缀和。

假设我们要求原序列的前 x 项之和,那就是在差分数组上求前 x 种前缀和之和,其中,差分数组的第 1 项被计算了 x 次,第 2 项被计算了 x-1 次,以此类推,第 x 项被计算了 1 次。

记差分数组第 i 项为 d_i,那么这个式子可以写成:

\sum_{i=1}^{x} d_i \times (x-i+1)

拆括号变形:

\sum_{i=1}^{x} d_i \times (x+1) - \sum_{i=1}^{x} d_i \times i

注意到第一项可以直接通过 d_i 的树状数组求出,而第二项需要维护 d_i \times i 的树状数组。

权值树状数组上二分求第 k

与倍增的思想很类似,也很像求 LCA 的过程。从 2^{\log n} 枚举到 2^{0},能加则加。

LL Kth(LL k){
    LL res=0,cnt=0;
    for(int i=20;i>=0;i--){
        if((res+(1<<i))>Mx)continue;
        if(cnt+c[res+(1<<i)]<k)
            res+=(1<<i),cnt+=c[res];
    }
    return res+1;
}

例题讲解

P1168 中位数

其实这个东西也可以用堆做。

考虑维护权值树状数组,但是需要先离散化。

如果当前被加入树状数组的数字恰好是奇数个,也就是我们要求解答案的位置了,就使用二分,二分得出最小且出现个数 > \lfloor \frac{i}{2} \rfloor 的数字,但是由于我们加了离散化,二分出的结果是离散化后的值,因此还需要额外开个数组映射离散化后的值到原值。

时间复杂度好像有两只 \log 但是问题不大。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e5+5;
LL n,a[N],cnt,c[N],bel[N];
map<LL,LL> Ls;
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
void upd(int x,LL k){while(x<=cnt)c[x]+=k,x+=x&-x;return;}
LL ask(LL x){LL sum=0;while(x)sum+=c[x],x-=x&-x;return sum;}
int main(){
    n=read();
    for(int i=1;i<=n;i++)a[i]=read(),Ls[a[i]]=0;
    for(auto &u:Ls)u.se=(++cnt),bel[cnt]=u.fr;
    for(int i=1;i<=n;i++)a[i]=Ls[a[i]];
    for(int i=1;i<=n;i++){
        upd(a[i],1);
        if(i%2==0)continue;
        int l=1,r=cnt,res=0;
        while(l<=r){
            int mid=(l+r)/2;
            LL sum=ask(mid);
            if(sum>i/2)res=mid,r=mid-1;
            else l=mid+1;
        }cout<<bel[res]<<"\n";
    }
    return 0;
}

P2345 MooFest G

首先根据坐标点排序。

根据一个点来算,听力左边比它小的和右边比它小的,分开跑两次树状数组。

树状数组不仅要维护听力比它差的奶牛个数,还要维护它们的坐标点之和。于是还得开两个树状数组。

第一次是正序遍历,枚举到一个 i,取到左边听力比它小的奶牛个数 cnt 和坐标点之和 sum,那么这一轮对答案的贡献就是 v_i ( cnt \times x_i - sum )

第二次是倒序遍历,枚举到一个 i,取右边左边听力比它小的奶牛个数 cnt 和坐标点之和 sum,那么这一轮对答案的贡献则是 v_i ( sum - cnt \times x_i )

记得开 long long

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 2e4+5;
struct cow{LL id,val;}a[N];
LL n,cnt,c[N],cc[N],Ans,Mx;
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
bool cmp(cow c1,cow c2){return c1.id<c2.id;}
void upd(int x,LL k){while(x<=Mx)c[x]+=k,x+=x&-x;return;}
LL ask(LL x){LL sum=0;while(x)sum+=c[x],x-=x&-x;return sum;}
void Upd(int x,LL k){while(x<=Mx)cc[x]+=k,x+=x&-x;return;}
LL Ask(LL x){LL sum=0;while(x)sum+=cc[x],x-=x&-x;return sum;}
int main(){
    n=read();
    for(int i=1;i<=n;i++)
        a[i].val=read(),a[i].id=read();
    sort(a+1,a+n+1,cmp);
    for(int i=1;i<=n;i++)
        Mx=max(Mx,a[i].val);
    for(int i=1;i<=n;i++){
        LL cnt=Ask(a[i].val);
        Ans+=(cnt*a[i].id-ask(a[i].val))*a[i].val;
        upd(a[i].val,a[i].id);
        Upd(a[i].val,1);
    }for(int i=1;i<=Mx;i++)c[i]=0,cc[i]=0;
    for(int i=n;i>=1;i--){
        LL cnt=Ask(a[i].val-1);
        Ans+=(ask(a[i].val-1)-cnt*a[i].id)*a[i].val;
        upd(a[i].val,a[i].id);
        Upd(a[i].val,1);
    }cout<<Ans<<"\n";
    return 0;
}

P3531 LIT-Letters

交换相邻两个数是什么概念啊,你说得对,逆序对!但是是求什么东西的逆序对呢?

找出 a 串中第 x 次出现的字符 c_1 与其在 b 中第 x 次出现的相同的字符 c_2 的位置,计入 p 数组。然后对这个 p 求逆序对就行了。

为什么?因为我们要通过交换使得 a 串对应的 p 数组与 b 串的 p 数组相同,但 a 串与 a 串自身处理出来的 p 数组就是生序排列,所以交换的次数就等同于 b 串与 a 串处理出来的 p 的逆序对个数啦。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
LL n,a[N],Ans,c[27][N],cnt[27],d[27],p[N];
string s,t;
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
void upd(int x,int k){while(x<=n)p[x]+=k,x+=x&-x;return;}
int ask(int x){int sum=0;while(x)sum+=p[x],x-=x&-x;return sum;}
int main(){
    n=read();
    cin>>s>>t;s=" "+s,t=" "+t;
    for(int i=1;i<=n;i++){
        int x=s[i]-'A'+1;
        c[x][++cnt[x]]=i;
    }for(int i=1;i<=n;i++){
        int x=t[i]-'A'+1;
        a[i]=c[x][++d[x]];
    }for(int i=1;i<=n;i++){
        Ans+=ask(n)-ask(a[i]-1);
        upd(a[i],1);
    }cout<<Ans<<"\n";
    return 0;
}

P4392 Sound 静音问题

发现这题可以直接套用上面提到的【求区间最值】的方式做。

但是不太对啊,这题的区间是一个滑动窗口的形式,是不是可以更简单地解决呢?

当然可以。同步维护权值树状数组,滑动窗口,某个值被移进来了就让其对应数值个数 +1,某个值被移出去了就让其对应数值个数 -1

这个时候你就要问了我们怎么求区间最大和最小呢?很简单,注意到可以用上面提到的【权值树状数组上二分求第 k 小】的方法,由于区间长度一定为 m,于是我们查分别查 \text{Kth}(1)\text{Kth}(m) 就可以啦。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
LL n,m,can,Mx,a[N],c[N],Anscnt;
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
void upd(int x,LL k){while(x<=Mx)c[x]+=k,x+=x&-x;return;}
LL ask(LL x){LL sum=0;while(x)sum+=c[x],x-=x&-x;return sum;}
LL Kth(LL k){
    LL res=0,cnt=0;
    for(int i=20;i>=0;i--){
        if((res+(1<<i))>Mx)continue;
        if(cnt+c[res+(1<<i)]<k)
            res+=(1<<i),cnt+=c[res];
    }
    return res+1;
}
int main(){
    n=read(),m=read(),can=read();
    for(int i=1;i<=n;i++)
        a[i]=read()+1,Mx=max(Mx,a[i]);
    for(int i=1;i<=n;i++){
        upd(a[i],1);
        if(i-m>=1)upd(a[i-m],-1);
        if(i<m)continue;
        LL D=Kth(m)-Kth(1);
        if(D<=can){cout<<i-m+1<<"\n";Anscnt++;}
    }if(!Anscnt)cout<<"NONE\n";
    return 0;
}

CF703D Mishka and Interesting sum

好玩的题。

不难发现,如果题目要求的是区间内出现奇数次的数字的异或和,那么这个题目直接用前缀异或和就可以解决了,因为偶数次会抵消。

但是现在要求的却是偶数次的,所以我们需要再异或上该区间内不重复的数字异或和。

那现在的问题就转化为,求一个区间内去重后所有数字的异或和。

发现这个东西非常莫队,但是数据范围有 10^6 啊,想冲过去得卡常吧?算了,还是考虑考虑如何用树状数组做吧。

首先得把所有数字放进树状数组,维护当前每个位置的值,便于快速查找区间异或和。

离线,把所有询问按照右端点升序排序,然后依次处理每个操作。

遍历 now_R 为当前遍历到的右端点。如果遇到某个值 x 在当前遍历到的位置中出现了超过 2 次,就把它的上一次出现在树状数组里对应位置修改为 0,保留这次出现。这样你就完美解决了去重。更新 now_R 至当前遍历到的节点,往后递推,然后查询本次询问的区间即可。

问题来了,怎么保证出现新的之后把旧的抹去不会影响后面统计答案?因为你已经按照右端点升序排序,所以不用担心后面出现包含旧的却不包含新的的情况。如果包含旧的,一定包含新的,因为右端点是升序排序的啦!

最后还原查询顺序输出即可,整体还是非常简单的。

注意需要不断维护每种数字最后一次出现的位置。

#include<bits/stdc++.h>
#define LL long long
#define UInt unsigned int
#define ULL unsigned long long
#define LD long double
#define pii pair<int,int>
#define pLL pair<LL,LL>
#define pDD pair<LD,LD>
#define fr first
#define se second
#define pb push_back
#define isr insert
using namespace std;
const int N = 1e6+5;
struct query{LL l,r,ans,id;}q[N];
LL n,m,a[N],s[N],c[N],now;
map<LL,LL> lst;
LL read(){
    LL su=0,pp=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')pp=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){su=su*10+ch-'0';ch=getchar();}
    return su*pp;
}
bool cmp(query q1,query q2){return q1.r<q2.r;}
bool cmp2(query q1,query q2){return q1.id<q2.id;}
void upd(int x){LL p=a[x];while(x<=n)c[x]^=p,x+=x&-x;return;}
LL ask(int x){LL sum=0;while(x)sum^=c[x],x-=x&-x;return sum;}
int main(){
    n=read();
    for(int i=1;i<=n;i++)
        a[i]=read(),s[i]=s[i-1]^a[i],upd(i);
    m=read();
    for(int i=1;i<=m;i++)
        q[i].l=read(),q[i].r=read(),q[i].id=i;
    sort(q+1,q+m+1,cmp);
    for(int i=1;i<=m;i++){
        while(now<q[i].r){
            now++;
            if(lst[a[now]]!=now){
                if(lst[a[now]])
                    upd(lst[a[now]]),a[lst[a[now]]]=0;
                lst[a[now]]=now;
            }
        }q[i].ans=ask(q[i].r)^ask(q[i].l-1);
        q[i].ans^=s[q[i].r]^s[q[i].l-1];
    }sort(q+1,q+m+1,cmp2);
    for(int i=1;i<=m;i++)cout<<q[i].ans<<"\n";
    return 0;
}

简单总结

树状数组,是一种好用的 \log 时间复杂度数据结构,其可以解决单点修改、差分区间查询操作,还可以通过维护差分数组来区间修改单点查询。还有很多用处,比如维护当前出现的数字中第 k 小的、维护区间最值,区修区查有时候也可以用它解决哦!用处极为广泛,常被用作一种解题的工具,因为容易理解、实现简单、速度高效,经常在各大题目里都能看见它的身影。总之,它是一种非常棒的数据结构啦!> <

码这么多字也不容易,还麻烦你留个赞支持一下,真是太感谢啦!