【学习笔记】树状数组

NCC79601

2019-08-08 18:58:45

Personal

# 定义 树状数组就是为前缀数组建立的树形结构,其最朴素的应用是**单点修改,区间查询**。 对于树状数组中的每一个节点$c[i]$,其管辖$[i-lowbit(i)+1,\ i]$这一段区间,其中$lowbit(i)$计算的是$i$的二进制数位中最靠右的$1$所表示的数。例如$6_{(10)}=110_{(2)}$,那么$c[6]$管辖的即是$[5,6]$这个区间。 可以证明一个数$n$的二进制数位中最多只有$log(n)$个$1$,因此树状数组的复杂度在**最坏情况下**为$O(logn)$,而最优情况下为$O(1)$。由于树状数组的本质是一个**特殊的前缀数组**,因此空间开销为$n$,这比树状数组的稳定$O(logn)$复杂度、空间开销$4n$都要优秀。同时,树状数组的代码量远远小于线段树,因此在其应用范围内,树状数组不失为一个优秀的选择。 # 改良 由于树状数组特性,其原生只支持单点修改、区间查询;而面临**区间修改、区间查询**的情况,树状数组似乎失去了作用。实际上,树状数组也能够进行区间修改、区间查询。 考虑一个差分数组$d[]$,可以知道$a[n]=\sum_{i=1}^n d[i]$。那么: $$s[n]=\sum_{i=1}^n a[i]=\sum_{i=1}^n \sum_{j=1}^i d[j]$$ $$=n\cdot d[1]+(n-1)\cdot d[2]+\cdots+2\cdot d[n-1] + 1\cdot d[n].$$ 对这个式子进行处理,可以得到: $$\text{原式}=n\cdot(d[1]+d[2]+\cdots+d[n])-(0\cdot d[1]+1\cdot d[2]+\cdots+(n-1)\cdot d[n]).$$ 也就是说,$s[n]$可以拆成两个部分,一个是$n\cdot\sum_{i=1}^nd[i]$,另一个是$-\sum_{i=1}^n(i-1)\cdot d[i]$,这两个部分就可以**用两个树状数组分别维护**,每次区间修改都对两棵树进行修改,每次区间查询就进行一次运算即可。 --- 完整代码: ([P3372](https://www.luogu.org/problem/P3372)) ```cpp #include <bits/stdc++.h> #define lowbit(x) (x & (-x)) using namespace std; typedef long long ll; const int MAXN = 1e5 + 10; int n, m; ll c1[MAXN], c2[MAXN], a[MAXN]; void add(ll *c, int x, int v) { while(x <= n) { c[x] += v; x += lowbit(x); } } ll query(ll *c, int x) { ll res = 0; while(x) { res += c[x]; x -= lowbit(x); } return res; } void edit(int l, int r, ll k) { add(c1, l, k); add(c1, r + 1, - k); add(c2, l, k * (l - 1)); add(c2, r + 1, - k * r); } ll presum(int x) { return x * query(c1, x) - query(c2, x); } ll sum(int l, int r) { return presum(r) - presum(l - 1); } void init() { memset(c1, 0, sizeof(c1)); memset(c2, 0, sizeof(c2)); for(int i = 1; i <= n; i++) { add(c1, i, a[i] - a[i - 1]); add(c2, i, (i - 1) * (a[i] - a[i - 1])); } } int main() { scanf("%d%d", &n, &m); for(int i = 1; i <= n; i++) scanf("%d", &a[i]); init(); int opt, x, y; ll k; while(m--) { scanf("%d%d%d", &opt, &x, &y); switch(opt) { case 1: scanf("%lli", &k); edit(x, y, k); break; case 2: printf("%lli\n", sum(x, y)); break; } } return 0; } ``` # 二次改良 **例题** [LOJ 10115](https://loj.ac/problem/10115) 考虑把每次区间加的操作抽象为一对括号$()$,那么每次询问$[l,r]$区间有多少种树时,答案就可以转化为$[1,r]$区间内的左括号数减去$[1,l)$区间内的右括号数,因此直接使用两个树状数组维护左右括号数即可。 ```cpp #include <bits/stdc++.h> using namespace std; const int MAXN = 5e4 + 10; int c1[MAXN], c2[MAXN]; int n, m; void add(int *c, int x, int v) { for ( ; x <= n; x += x & (-x)) c[x] += v; return ; } int query(int *c, int x) { int res = 0; for ( ; x; x -= x & (-x)) res += c[x]; return res; } int main() { scanf("%d %d", &n, &m); for (int k, l, r; m; m--) { scanf("%d %d %d", &k, &l, &r); if (k == 1) { add(c1, l, 1); add(c2, r, 1); } else printf("%d\n", query(c1, r) - query(c2, l - 1)); } return 0; } ``` --- **例题** [POJ 1990](http://poj.org/problem?id=1990) # 分析 这道题乍一看是个$O(n^2)$,然而很明显$20000$的数据范围限定了复杂度只能是$O(nlogn)$。如果扫一遍所有奶牛的复杂度是$O(n)$,那么就必须在$O(logn)$时间内完成对一头牛的计算。考虑如何用树状数组做这道题: 首先,一头奶牛要对答案产生贡献,其必须与$v$小于自身的奶牛交流。这也就意味着,如果以$v$为关键字对原序列进行升序排序,那么在$v$的角度就转化为了一个**单调性问题**:每头奶牛与其之前的奶牛交流就会对答案产生贡献。问题在于如何处理“距离$\times$最大阈值$=$贡献”这个恶心的算式。 由于已经转化为一个单调性问题,所以不用再枚举每头奶牛,当前奶牛$i$能产生的贡献即是$i\times v[i]\times sum(\left|x[i]-x[j]\right|)\ (j<i)$。拆掉绝对值,就把贡献砍成两部分:左边的奶牛和右边的奶牛。 所以这里维护两个树状数组,$c1[]$维护坐标,$c2[]$维护个数;每次查询完左边信息以后,再利用左边信息获得右边信息,最后将当前奶牛加入树状数组当中。具体操作可以看代码。 ```cpp #include <iostream> #include <stdio.h> #include <algorithm> #include <cstring> using namespace std; typedef long long ll; const int MAXN = 20010; struct type_cow { int x, v; bool operator < (const type_cow &rhs) const { return v < rhs.v; } } cow[MAXN]; int n, max_x = 0; ll c1[MAXN], c2[MAXN]; // c1维护坐标,c2维护个数 int lowbit(int x) { return x & (-x); } void add(ll *c, int pos, ll v) { while(pos <= max_x) { c[pos] += v; pos += lowbit(pos); } } ll query(ll *c, int pos) { ll res = 0; while(pos) { res += c[pos]; pos -= lowbit(pos); } return res; } int main() { memset(c1, 0, sizeof(c1)); memset(c2, 0, sizeof(c2)); scanf("%d", &n); for(int i = 1; i <= n; i++) { scanf("%d%d", &cow[i].v, &cow[i].x); max_x = max(max_x, cow[i].x); } sort(cow + 1, cow + n + 1); ll ans = 0, dis, num; for(int i = 1; i <= n; i++) { // left dis = query(c1, cow[i].x); num = query(c2, cow[i].x); ans += (num * cow[i].x - dis) * cow[i].v; // right dis = query(c1, max_x) - dis; num = (i - 1) - num; ans += (dis - num * cow[i].x) * cow[i].v; add(c1, cow[i].x, cow[i].x); add(c2, cow[i].x, 1LL); } printf("%lli", ans); return 0; } ```