树状数组

· · 题解

P3374 【模板】树状数组 1

树状数组

树状数组,是一个查询和修改复杂度都为log(n)的数据结构。主要用于查询任意两位之间的所有元素之和,但是每次只能修改一个元素的值。

经过简单修改可以在log(n)的复杂度下进行范围修改,但是这时只能查询其中一个元素的值(如果加入多个辅助数组则可以实现区间修改与区间查询)。

这种数据结构(算法)并没有C++和Java的库支持,需要自己手动实现。

在Competitive Programming的竞赛中被广泛的使用。树状数组和线段树很像,但能用树状数组解决的问题,基本上都能用线段树解决,而线段树能解决的树状数组不一定能解决。

相比较而言,树状数组效率要高很多。

树状数组概念

假设数组a[1..n],那么查询a[1]+...+a[n]的时间是log级别的,而且是一个在线的数据结构,支持随时修改某个元素的值,复杂度也为log级别。

令这棵树的结点编号为C1,C2...Cn。令每个结点的值为这棵树的值的总和,那么容易发现:

C1 = A1

C2 = A1 + A2

C3 = A3

C4 = A1 + A2 + A3 + A4

C5 = A5

C6 = A5 + A6

C7 = A7

C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8

...

C16 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8 + A9 + A10 + A11 + A12 + A13 + A14 + A15 + A16

这里有一个有趣的性质:

设节点编号为x,那么这个节点管辖的区间为2^k(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax,

所以很明显:Cn = A(n – 2^k + 1) + ... + An

算这个2^k有一个快捷的办法。

lowbit这个函数的功能就是求某一个数的二进制表示中最低的一位1,举个例子,x = 6,它的二进制为110,那么lowbit(x)就返回2,因为最后一位1表示2。

定义一个函数如下即可:

int lowbit(int x)
{
    return x&(x^(x-1));
}

利用机器补码特性,也可以写成

int lowbit(int x)
{
    return x&-x;
}

当想要查询一个SUM(n)(求a[n]的和),可以依据如下算法即可:

step1: 令sum = 0,转第二步; step2: 假如n <= 0,算法结束,返回sum值,否则sum = sum + Cn,转第三步; step3: 令n = n – lowbit(n),转第二步。 可以看出,这个算法就是将这一个个区间的和全部加起来,为什么是效率是log(n)的呢?以下给出证明:

n = n – lowbit(n)这一步实际上等价于将n的二进制的最后一个1减去。而n的二进制里最多有log(n)个1,所以查询效率是log(n)的。

那么修改呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有log(n)的祖先。

所以修改算法如下(给某个结点i加上x):

step1: 当i > n时,算法结束,否则转第二步; step2: Ci = Ci + x, i = i + lowbit(i)转第一步。 i = i +lowbit(i)这个过程实际上也只是一个把末尾1补为0的过程。

对于数组求和来说树状数组简直太快了!

注:
求lowbit(x)的建议公式:

lowbit(x):=x&-x;
lowbit(x):=x&(x^(x-1));
lowbit(x)即为2^k的值。

充分性

很容易知道C8表示A1~A8的和,但是C6却是表示A5~A6的和,为什么会产生这样的区![]别的呢?或者说发明她的人为什么这样区别对待呢?

答案是,这样会使操作更简单!看到这相信有些人就有些感觉了,为什么复杂度被log了呢?

可以看到,C8可以看作A1~A8的左半边和+右半边和,而其中左半边和是确定的C4,右半边其实也是同样的规则把A5~A8一分为二……继续下去都是一分为二直到不能分树状数组巧妙地利用了二分,树状数组并不神秘,关键是巧妙! (https://s1.ax1x.com/2018/08/09/PyKpDg.png)

#include<iostream>
#include<cmath>
#include<algorithm>
#include<string>
#include<cstring>
#include<cstdio>
#include<cstdlib>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
using namespace std; 
const int N = 5*1e5 +100;
int arr[N],c[N],n;

int lowbit(int x)
{
    return x & (-x);
}
void update(int id,int x)
{
    while(id<=n)
    {
        c[id]+=x;
        id += lowbit(id);
    }
}

int query(int id)
{
    int ans = 0;
    while(id!=0)
    {
        ans+=c[id];
        id-=lowbit(id);
    }
    return ans;
}

int main()
{
    int m;
    cin >> n >> m;
    for(int i=1;i<=n;i++)
    {
        cin>>arr[i];
        update(i,arr[i]);
    }
    while(m--)
    {
       int op,x,y;
       cin >> op >> x >> y;
       if(op==1)
        update(x,y);
       if(op==2)
        cout << query(y) - query(x-1) << endl;
    }
    return 0;
}