算法学习——zkw线段树

· · 个人记录

算法学习——zkw线段树

zkw线段树是什么?

zkw线段树简单来说,就是非递归线段树。为什么叫zkw线段树?这是因为张昆玮dalao在《统计的力量》中很好地介绍了这个数据结构,所以常常被称为zkw(名字缩写)线段树。普通的线段树采用的是自顶向下的构树方式,而zkw线段树却反其道而行之——自下向顶的进行构建树。

为什么要使用zkw线段树

众所周知,线段树最大的缺点就是常数大,且代码量大,尤其是在添加lazy标记后的代码量,而且普通的线段树采用的递归方式的写法,所以很难让入门的同学理解。

而zkw线段树,采用的是自下向顶的构建树的方式,只用循环便可以搞定线段树的构树,修改和查询。这不仅通俗易懂,而且常数很小,代码量也很少,并且由于zkw线段树的特殊建树方式,使它的结构是个非常优美的完全二叉树。

zkw线段树的原理

就像上面说的那样,zkw线段树是自下向顶的构建树,也就是说,我们会先确定好每个最底层的节点所对应的值。然后一层层的去更新他们的父亲节点。这样来达到我们构建一颗线段树的目的。

而要找到一个孩子节点的父亲也很简单,我们只需要让节点x进行除以二操作即可x >>= 1

zkw线段树是一颗完全二叉树,也就是说除根结点以外,每一个节点都会有自己的兄弟节点,而兄弟节点的编号为x^1(这里的^是异或操作)。例如3号节点,它的兄弟节点便是3^1 = 2,即2号节点。后面我们也会频繁的使用到^操作来完成我们寻找兄弟节点的操作。

构建树

首先我们先了解zkw线段树的结构,这里我们用一组长度为5的数据[5, 1, 2, 3, 4]来表示。

最底层总共有多少个节点?

首先我们要先知道用于储存原始数据的那一层的节点有多少个。

我们先算一算需要多少个才能够将我们长度为n的数据全部储存在统一深度的节点中,由于zkw线段树是完全二叉树,也就是说储存这一层的总结点的数量是二的倍数(即2^k个(这里的^是次方操作))。显然,要存储长度为n个数据的话,我们就必须让2^k大于等于n才行。

所以答案是大于等于n的最小的2的倍数。。。。吗?

先说结论,答案是大于等于(n + 2)的最小的2的倍数。至于为什么,会在区间操作那里提到,现在我们只需要知道要大于等于的是(n + 2),而不是n即可。

例如上面长度为5(即n = 5),最小的大于等于7(5+2)的2的倍数是8,所以我们储存原始数据那一层的节点个数为8。

如何建树?

刚刚我们求出了最底层(即储存原始数据的那一层)有多少个节点,而这有什么用呢?

首先我们先了解原始数据储存在最底层的那些节点中。由图可以得知,分别储存在下标为9(1001), 10(1010) ...... 15(1101)中(至于为什么不是从8开始,这个问题也会在区间操作中讲到)。我们可以观察到储存数据的下标为8+i,8是大于等于(n + 2)的最小的2的倍数,i为数据中的第i个数据。

所以我们可以先求出来大于等于(n + 2)的值,接下来只需要让这个值+i为下标便可与储存相对应的第i个数据。

首先求出大于等于(n + 2)的值

for (m = 1; m < n + 2; m <<= 1);

注:这里的m是全局变量

先让m等于2的0次方(即1),如果m小于n+2的话就让m*=2,直到m大于等于n+2

接下来是存储数据

for (int i = m + 1; i <= m + n; i ++)
    cin >> sum[i];

最后一步与线段树相同:更新他们的父节点。

for (int i = m - 1; i; i --)
    sum[i] = sum[i << 1] + sum[i << 1 | 1];

这里让i = m - 1是由于m - 1为上一层的最大节点(m为最底层的第一个节点),也就是说从m-1到1,全部都是我们需要更新的节点。(相信会线段树的你也一定一眼就理解这一段了)

全代码

void build(int n) {
    for (m = 1; m < n + 2; m <<= 1);
    for (int i = m + 1; i <= m + n; i ++)
        cin >> sum[i];
    for (int i = m - 1; i; i --)
        sum[i] = sum[i << 1] + sum[i << 1 | 1];
}

单点修改

单点修改的思想十分简单,由于我们可以很轻松的知道需要改变的数的节点编号是x(需要改变第x位置的数) + m(上文那个全局变量的m)。所以我们仅仅需要像树状数组那样更新它的父节点即可。

例如这里我们让第三个数+4,也就是原数组变为[5, 1, 6, 3, 4]

第一步先找到储存在zkw线段树的节点的位置,即x += m。由于m设置的是全局变量,所以我们可以直接调用建树时得到的m的值。

接下来只需要更新它的父节点即可,而它的父节点的节点编号可以像树状数组一样的方式找到x >>= 1。最后更新一下sum数组sum[x] += v;

void update_node(int x, int v) {
    for (x += m; x; x >>= 1)
        sum[x] += v;
}

区间查询

首先回答一下为什么上文中多次提到的(n + 2)

这是由于我们在进行区间修改区间查询的时候,所查询的空间并非闭区间的[x, y],而是开区间的(x - 1, y + 1)(至于为什么是开区间后面会说到)。n+2方便我们空出来m+0的节点与m+n+1的节点位置,这也方便我们直观的把第i个数据存储到相应的m+i节点上。如果我们只要m大于等于n的话,那么我们存储数据应该是从m到m+n-1,例如n = 7,我们存储的节点编号便是8~15,我们要修改区间1到7的数的话,则我们进行操作的区间是(7, 16),而7号节点是一个父亲节点,这样就会变得十分麻烦。而我们如果采用n+2的话,让数据全部放在m, m + n + 1之间,这样子哪怕是全部修改,也可以保证开区间的端点都在最底层节点中。

接下来解释为什么是采用开区间(x - 1, y + 1)的方式

这里其实是zkw线段树的几个性质相关,对于查询区间[x, y]

  1. 如果x是左儿子,那么x兄弟节点必在区间内
  2. 如果y是右儿子,那么y兄弟节点必在区间内
  3. 如果x是右儿子,那么直接去向上找它的父亲
  4. 如果y是左儿子,那么直接去向上找它的父亲
  5. 如果x与y同为兄弟,那么终止查询

只看这个的话想必是比较难以理解的,那么这里我们以查询区间[2, 5]为例。

首先转换为开区间(1, 6),x = 1 + 8(m) = 9 , y = 6 + 8 = 14。

x += m - 1, y += m + 1

发现符合x与y是完全包含所查询区间的,而且由于是开区间,也并不会把x,y只想的数包含进去。

接下来我们发现这符合上面的3,4条,所以我们直接让x,y再去指向它们的父亲节点x >>= 1, y >>= 1

直到x与y为兄弟节点时结束x ^ y ^ 1 == 0;

这时候我们发现x是左儿子节点,则x的兄弟节点x ^ 1 = 0101必在区间内。y是右儿子节点,则y的兄弟节点y ^ 1 = 0110 必在区间内。

​ 判断是否位左儿子节点的方法:如果x是左儿子节点,那么它必然是偶数,即~ x & 1 == 1。右节点必为奇数即y & 1 == 1

所以我们要做的就是把x,y节点的兄弟节点的值加进我们定义的答案数组中

if (~ x & 1) ans += sum[x ^ 1];//x ^ 1就是x的兄弟节点(上面说过)
if (y & 1) ans += sum[y ^ 1];

下一步我们发现x与y是兄弟节点x ^ y ^ 1 == 0所以我们的查询也就结束了

上述例子,就是zkw线段树的区间查询操作。

全代码

int query(int x, int y) {
    int ans = 0;
    for (x += m - 1, y += m + 1; x ^ y ^ 1; x >>= 1, y >>= 1) {
        if (~ x & 1) ans += sum[x ^ 1];
        if (y & 1) ans += sum[y ^ 1];
    }
    return ans;
}

区间修改

在传统线段树中,区间修改采用了懒标记的方法来降低时间复杂度,但这会让代码量暴增。

在zkw线段树中,由于是非递归的,所有标记下传是比较困难的,所以我们选择了标记永久化的方法。mark[u]用于记录节点u的数值变化,即永久标记化数组

更详细地说,与上文所说的区间查询类似,当左端点是左儿子/右端点是右儿子时,那么其兄弟节点一定是所要修改的区间内的节点,那么我们就需要对这个兄弟节点做上标记,与标记下传不同的是,查询时要将这个标记上传到各级祖先,直到根节点。

这里我们以修改区间[3,5]为例,让这个区间整体+1

由于修改时所使用的区间为开区间,所以两个端点分别为m + 2,m + 6

这时我们发现左端点是左儿子,也就是说它的兄弟节点在我们要修改的区间内,这里我们要做的有两件事——修改兄弟节点的值和修改兄弟节点的永久标记化数组。

if (~ x & 1) sum[x ^ 1] += v, mark[x ^ 1] += v;
if (y & 1)   sum[y ^ 1] += v, mark[y ^ 1] += v;

到了第三层节点,我们发现左端点虽然是右节点,但由于它是修改区间的父亲节点,所以它的值依然需要更新,那么我们就需要再增加一个操作——每次循环都更新一下当前节点。

而这里我们需要引入另一个概念——len

len我们定义其为实际修改的区间长度,例如在第一层循环中,如果需要修改的话,那么一定是只修改长度为一的值。也就是len=1。例如在第二层循环中,我们发现y节点是右儿子,也就是说它的兄弟节点是我们需要去修改的,显然要让它的兄弟节点的值+2才对,因为这个兄弟节点包含了两个需要修改值的节点,也就是len = 2。我们可以发现每层循环都需要让len的值2,即len <<= 1,而对端点的兄弟节点值的修改,也就是`len v`

可这还并没有解决更新当前节点的问题

所以我们还需要另外两个变量xl, yl,它们分别代表左/右端点实际修改的区间长度。例如在第一层循环的时候,显然是不能够更新当前端点节点的,所以我们要将这两个变量都初始化为0,xl = 0, yl = 0。到了第二层循环,由于在第一层循环中左端点的兄弟节点更新过,而此时左端点变成了父亲节点,即这个节点实际上修改区间为1,即xl += len(要在上一层循环更新),要更新当前节点的值也很容易想到了,*实际修改长度增减的值*,即`sum[x] += xl v, sum[y] += yl * v;`

而我们刚刚在上面说的修改兄弟节点和永久标记化数组的地方也需要改下:

if (~ x & 1) sum[x ^ 1] += v * len, xl += len, mark[x ^ 1] += v;
if (y & 1) sum[y ^ 1] += v * len, yl += len, mark[y ^ 1] += v;
//mark数组只加v是因为在以后查询的时候,我们会让其乘上实际修改长度,所以现在并不需要乘len,仅标记此节点被修改过即可。

到了下一层循环,左端点的实际修改长度为一xl = 1,右端点的修改长度为二yl = 2。而此时由于是兄弟节点,我们需要退出循环。

退出循环后,我们还没有修改当前端点的值,所以需要再用一个while循环一步步更新父亲节点的值,直到更新到根节点。

while (x) {
    sum[x] += xl * v;
    sum[y] += yl * v;
    x >>= 1, y >>= 1;
}

全代码

void update_sum(int x, int y, int v) {
    int len = 1, xl = 0, yl = 0;
    for (x += m - 1, y += m + 1; x ^ y ^ 1; x >>= 1, y >>= 1, len <<= 1) {
        sum[x] += xl * v, sum[y] += yl * v;
        if (~ x & 1) sum[x ^ 1] += v * len, xl += len, mark[x ^ 1] += v;
        if (y & 1) sum[y ^ 1] += v * len, yl += len, mark[y ^ 1] += v;
    }
    while (x) {
        sum[x] += xl * v;
        sum[y] += yl * v;
        x >>= 1, y >>= 1;
    }
}

加上了永久化标记后的区间查询

区间修改会影响区间查询,我们同样需要在区间查询中引入len, xl, yl来实现对永久化标记的利用。

当左/右端点是左/右儿子时,我们同样需要更新一下实际修改区间的值xl + len,同时加上兄弟端点的值(这个和原本的一样)

而对于每个节点,只需要加上这个节点*实际的修改长度永久化标记数组记录的修改的值*即可,即`ans += xl mark[x] + yl * mark[y];`

由于永久化标记数组并不会下传标记,所以我们在左右端点都是兄弟节点退出循环后,还需要继续向上推到根节点

while (x) ans += xl * mark[x] + yl * mark[y], x >>= 1, y >>= 1;

全代码

int query_sum(int x, int y) {
    int ans = 0;
    int len = 1, xl = 0, yl = 0;
    for (x = x + m - 1, y = y + m + 1; x ^ y ^ 1; x >>= 1, y >>= 1, len <<= 1) {
        ans += xl * mark[x] + yl * mark[y];
        if (~ x & 1) ans += sum[x ^ 1],xl += len;
        if (y & 1) ans += sum[y ^ 1], yl += len;
    }
    while (x) ans += xl * mark[x] + yl * mark[y], x >>= 1, y >>= 1;
    return ans;
}

全代码

#include <bits/stdc++.h>

#define up(i, j, n, k) for (int i = j; i <= n; i += k)
#define dw(i, j, n, k) for (int i = j; i >= n; i -= k)
#define LL long long
#define map unordered_mao

using namespace std;

const int mod = 1e9 + 7;
const int mn = 1e6 + 10;
const int INF = 0x3f3f3f3f;

int sum[mn], m, n, k, mark[mn];

void build(int n) {
    for (m = 1; m < n + 2; m <<= 1);
    for (int i = m + 1; i <= m + n; i ++)
        cin >> sum[i];
    for (int i = m - 1; i; i --)
        sum[i] = sum[i << 1] + sum[i << 1 | 1];
}

void update_node(int x, int v) {
    for (x += m; x; x >>= 1)
        sum[x] += v;
}

int query(int x, int y) {
    int ans = 0;
    for (x += m - 1, y += m + 1; x ^ y ^ 1; x >>= 1, y >>= 1) {
        if (~ x & 1) ans += sum[x ^ 1];
        if (y & 1) ans += sum[y ^ 1];
    }
    return ans;
}

void update_sum(int x, int y, int v) {
    int len = 1, xl = 0, yl = 0;
    for (x += m - 1, y += m + 1; x ^ y ^ 1; x >>= 1, y >>= 1, len <<= 1) {
        sum[x] += xl * v, sum[y] += yl * v;
        if (~ x & 1) sum[x ^ 1] += v * len, xl += len, mark[x ^ 1] += v;
        if (y & 1) sum[y ^ 1] += v * len, yl += len, mark[y ^ 1] += v;
    }
    while (x) {
        sum[x] += xl * v;
        sum[y] += yl * v;
        x >>= 1, y >>= 1;
    }
}

int query_sum(int x, int y) {
    int ans = 0;
    int len = 1, xl = 0, yl = 0;
    for (x = x + m - 1, y = y + m + 1; x ^ y ^ 1; x >>= 1, y >>= 1, len <<= 1) {
        ans += xl * mark[x] + yl * mark[y];
        if (~ x & 1) ans += sum[x ^ 1],xl += len;
        if (y & 1) ans += sum[y ^ 1], yl += len;
    }
    while (x) ans += xl * mark[x] + yl * mark[y], x >>= 1, y >>= 1;
    return ans;
}

int main() {
    cin >> n >> k;
    build(n);
    up (i, 1, k, 1) {
        int flag, x, y, v;
        cin >> flag >> x >> y;
        if (flag == 1) {
            cin >> v;
            update_sum(x, y, v);
        }
        else
            cout << query_sum(x, y) << "\n";
    }
    return 0;
}