树状数组
P3374 【模板】树状数组 1
树状数组
树状数组,是一个查询和修改复杂度都为log(n)的数据结构。主要用于查询任意两位之间的所有元素之和,但是每次只能修改一个元素的值。
经过简单修改可以在log(n)的复杂度下进行范围修改,但是这时只能查询其中一个元素的值(如果加入多个辅助数组则可以实现区间修改与区间查询)。
这种数据结构(算法)并没有C++和Java的库支持,需要自己手动实现。
在Competitive Programming的竞赛中被广泛的使用。树状数组和线段树很像,但能用树状数组解决的问题,基本上都能用线段树解决,而线段树能解决的树状数组不一定能解决。
相比较而言,树状数组效率要高很多。
树状数组概念
假设数组a[1..n],那么查询a[1]+...+a[n]的时间是log级别的,而且是一个在线的数据结构,支持随时修改某个元素的值,复杂度也为log级别。
令这棵树的结点编号为C1,C2...Cn。令每个结点的值为这棵树的值的总和,那么容易发现:
C1 = A1
C2 = A1 + A2
C3 = A3
C4 = A1 + A2 + A3 + A4
C5 = A5
C6 = A5 + A6
C7 = A7
C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8
...
C16 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8 + A9 + A10 + A11 + A12 + A13 + A14 + A15 + A16
这里有一个有趣的性质:
设节点编号为x,那么这个节点管辖的区间为2^k(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax,
所以很明显:Cn = A(n – 2^k + 1) + ... + An
算这个2^k有一个快捷的办法。
lowbit这个函数的功能就是求某一个数的二进制表示中最低的一位1,举个例子,x = 6,它的二进制为110,那么lowbit(x)就返回2,因为最后一位1表示2。
定义一个函数如下即可:
int lowbit(int x)
{
return x&(x^(x-1));
}
利用机器补码特性,也可以写成
int lowbit(int x)
{
return x&-x;
}
当想要查询一个SUM(n)(求a[n]的和),可以依据如下算法即可:
step1: 令sum = 0,转第二步; step2: 假如n <= 0,算法结束,返回sum值,否则sum = sum + Cn,转第三步; step3: 令n = n – lowbit(n),转第二步。 可以看出,这个算法就是将这一个个区间的和全部加起来,为什么是效率是log(n)的呢?以下给出证明:
n = n – lowbit(n)这一步实际上等价于将n的二进制的最后一个1减去。而n的二进制里最多有log(n)个1,所以查询效率是log(n)的。
那么修改呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有log(n)的祖先。
所以修改算法如下(给某个结点i加上x):
step1: 当i > n时,算法结束,否则转第二步; step2: Ci = Ci + x, i = i + lowbit(i)转第一步。 i = i +lowbit(i)这个过程实际上也只是一个把末尾1补为0的过程。
对于数组求和来说树状数组简直太快了!
注:
求lowbit(x)的建议公式:
lowbit(x):=x&-x;
lowbit(x):=x&(x^(x-1));
lowbit(x)即为2^k的值。
充分性
很容易知道C8表示A1~A8的和,但是C6却是表示A5~A6的和,为什么会产生这样的区![]别的呢?或者说发明她的人为什么这样区别对待呢?
答案是,这样会使操作更简单!看到这相信有些人就有些感觉了,为什么复杂度被log了呢?
可以看到,C8可以看作A1~A8的左半边和+右半边和,而其中左半边和是确定的C4,右半边其实也是同样的规则把A5~A8一分为二……继续下去都是一分为二直到不能分树状数组巧妙地利用了二分,树状数组并不神秘,关键是巧妙! (https://s1.ax1x.com/2018/08/09/PyKpDg.png)
#include<iostream>
#include<cmath>
#include<algorithm>
#include<string>
#include<cstring>
#include<cstdio>
#include<cstdlib>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
using namespace std;
const int N = 5*1e5 +100;
int arr[N],c[N],n;
int lowbit(int x)
{
return x & (-x);
}
void update(int id,int x)
{
while(id<=n)
{
c[id]+=x;
id += lowbit(id);
}
}
int query(int id)
{
int ans = 0;
while(id!=0)
{
ans+=c[id];
id-=lowbit(id);
}
return ans;
}
int main()
{
int m;
cin >> n >> m;
for(int i=1;i<=n;i++)
{
cin>>arr[i];
update(i,arr[i]);
}
while(m--)
{
int op,x,y;
cin >> op >> x >> y;
if(op==1)
update(x,y);
if(op==2)
cout << query(y) - query(x-1) << endl;
}
return 0;
}