【学习笔记】树状数组
NCC79601
2019-08-08 18:58:45
# 定义
树状数组就是为前缀数组建立的树形结构,其最朴素的应用是**单点修改,区间查询**。
对于树状数组中的每一个节点$c[i]$,其管辖$[i-lowbit(i)+1,\ i]$这一段区间,其中$lowbit(i)$计算的是$i$的二进制数位中最靠右的$1$所表示的数。例如$6_{(10)}=110_{(2)}$,那么$c[6]$管辖的即是$[5,6]$这个区间。
可以证明一个数$n$的二进制数位中最多只有$log(n)$个$1$,因此树状数组的复杂度在**最坏情况下**为$O(logn)$,而最优情况下为$O(1)$。由于树状数组的本质是一个**特殊的前缀数组**,因此空间开销为$n$,这比树状数组的稳定$O(logn)$复杂度、空间开销$4n$都要优秀。同时,树状数组的代码量远远小于线段树,因此在其应用范围内,树状数组不失为一个优秀的选择。
# 改良
由于树状数组特性,其原生只支持单点修改、区间查询;而面临**区间修改、区间查询**的情况,树状数组似乎失去了作用。实际上,树状数组也能够进行区间修改、区间查询。
考虑一个差分数组$d[]$,可以知道$a[n]=\sum_{i=1}^n d[i]$。那么:
$$s[n]=\sum_{i=1}^n a[i]=\sum_{i=1}^n \sum_{j=1}^i d[j]$$
$$=n\cdot d[1]+(n-1)\cdot d[2]+\cdots+2\cdot d[n-1] + 1\cdot d[n].$$
对这个式子进行处理,可以得到:
$$\text{原式}=n\cdot(d[1]+d[2]+\cdots+d[n])-(0\cdot d[1]+1\cdot d[2]+\cdots+(n-1)\cdot d[n]).$$
也就是说,$s[n]$可以拆成两个部分,一个是$n\cdot\sum_{i=1}^nd[i]$,另一个是$-\sum_{i=1}^n(i-1)\cdot d[i]$,这两个部分就可以**用两个树状数组分别维护**,每次区间修改都对两棵树进行修改,每次区间查询就进行一次运算即可。
---
完整代码: ([P3372](https://www.luogu.org/problem/P3372))
```cpp
#include <bits/stdc++.h>
#define lowbit(x) (x & (-x))
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int n, m;
ll c1[MAXN], c2[MAXN], a[MAXN];
void add(ll *c, int x, int v)
{
while(x <= n)
{
c[x] += v;
x += lowbit(x);
}
}
ll query(ll *c, int x)
{
ll res = 0;
while(x)
{
res += c[x];
x -= lowbit(x);
}
return res;
}
void edit(int l, int r, ll k)
{
add(c1, l, k);
add(c1, r + 1, - k);
add(c2, l, k * (l - 1));
add(c2, r + 1, - k * r);
}
ll presum(int x)
{
return x * query(c1, x) - query(c2, x);
}
ll sum(int l, int r)
{
return presum(r) - presum(l - 1);
}
void init()
{
memset(c1, 0, sizeof(c1));
memset(c2, 0, sizeof(c2));
for(int i = 1; i <= n; i++)
{
add(c1, i, a[i] - a[i - 1]);
add(c2, i, (i - 1) * (a[i] - a[i - 1]));
}
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++)
scanf("%d", &a[i]);
init();
int opt, x, y;
ll k;
while(m--)
{
scanf("%d%d%d", &opt, &x, &y);
switch(opt)
{
case 1:
scanf("%lli", &k);
edit(x, y, k);
break;
case 2:
printf("%lli\n", sum(x, y));
break;
}
}
return 0;
}
```
# 二次改良
**例题** [LOJ 10115](https://loj.ac/problem/10115)
考虑把每次区间加的操作抽象为一对括号$()$,那么每次询问$[l,r]$区间有多少种树时,答案就可以转化为$[1,r]$区间内的左括号数减去$[1,l)$区间内的右括号数,因此直接使用两个树状数组维护左右括号数即可。
```cpp
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e4 + 10;
int c1[MAXN], c2[MAXN];
int n, m;
void add(int *c, int x, int v) {
for ( ; x <= n; x += x & (-x))
c[x] += v;
return ;
}
int query(int *c, int x) {
int res = 0;
for ( ; x; x -= x & (-x))
res += c[x];
return res;
}
int main() {
scanf("%d %d", &n, &m);
for (int k, l, r; m; m--) {
scanf("%d %d %d", &k, &l, &r);
if (k == 1) {
add(c1, l, 1);
add(c2, r, 1);
} else
printf("%d\n", query(c1, r) - query(c2, l - 1));
}
return 0;
}
```
---
**例题** [POJ 1990](http://poj.org/problem?id=1990)
# 分析
这道题乍一看是个$O(n^2)$,然而很明显$20000$的数据范围限定了复杂度只能是$O(nlogn)$。如果扫一遍所有奶牛的复杂度是$O(n)$,那么就必须在$O(logn)$时间内完成对一头牛的计算。考虑如何用树状数组做这道题:
首先,一头奶牛要对答案产生贡献,其必须与$v$小于自身的奶牛交流。这也就意味着,如果以$v$为关键字对原序列进行升序排序,那么在$v$的角度就转化为了一个**单调性问题**:每头奶牛与其之前的奶牛交流就会对答案产生贡献。问题在于如何处理“距离$\times$最大阈值$=$贡献”这个恶心的算式。
由于已经转化为一个单调性问题,所以不用再枚举每头奶牛,当前奶牛$i$能产生的贡献即是$i\times v[i]\times sum(\left|x[i]-x[j]\right|)\ (j<i)$。拆掉绝对值,就把贡献砍成两部分:左边的奶牛和右边的奶牛。
所以这里维护两个树状数组,$c1[]$维护坐标,$c2[]$维护个数;每次查询完左边信息以后,再利用左边信息获得右边信息,最后将当前奶牛加入树状数组当中。具体操作可以看代码。
```cpp
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int MAXN = 20010;
struct type_cow
{
int x, v;
bool operator < (const type_cow &rhs) const
{
return v < rhs.v;
}
} cow[MAXN];
int n, max_x = 0;
ll c1[MAXN], c2[MAXN];
// c1维护坐标,c2维护个数
int lowbit(int x)
{
return x & (-x);
}
void add(ll *c, int pos, ll v)
{
while(pos <= max_x)
{
c[pos] += v;
pos += lowbit(pos);
}
}
ll query(ll *c, int pos)
{
ll res = 0;
while(pos)
{
res += c[pos];
pos -= lowbit(pos);
}
return res;
}
int main()
{
memset(c1, 0, sizeof(c1));
memset(c2, 0, sizeof(c2));
scanf("%d", &n);
for(int i = 1; i <= n; i++)
{
scanf("%d%d", &cow[i].v, &cow[i].x);
max_x = max(max_x, cow[i].x);
}
sort(cow + 1, cow + n + 1);
ll ans = 0, dis, num;
for(int i = 1; i <= n; i++)
{
// left
dis = query(c1, cow[i].x);
num = query(c2, cow[i].x);
ans += (num * cow[i].x - dis) * cow[i].v;
// right
dis = query(c1, max_x) - dis;
num = (i - 1) - num;
ans += (dis - num * cow[i].x) * cow[i].v;
add(c1, cow[i].x, cow[i].x);
add(c2, cow[i].x, 1LL);
}
printf("%lli", ans);
return 0;
}
```