线段树(Segment Tree)

· · 算法·理论

你说得对,但是线段树是一种用于维护区间信息的可修改的开放式数据结构。

关于线段树...

1.芝士什么

2.药肿么写

SO~

FIRST

什么是线段树?

线段树是一种利用二分思想,将一个长度不为1的区间划分为左右两个区间进行递归求解的数据结构。它能够将一条线段划分为一个树形结构,通过合并左右两区间信息来求得该区间的信息。

下图为一棵线段树形态(求和):

在以上树形图中,对于每一个结点都代表一个区间内所有之和。

其中,我们可以发现,对于∀sum[i](设代表区间[l,r]), i ∈[1,7],其左节点为sum[i2],代表区间[l,(l+r)/2];右节点为sum[i 2+1],代表区间[(l+r)/2+1,r]。

了解了什么是线段树,接下来我们来进行代码实现

SECOND

怎么写线段树

由于线段树是一种树形结构,其原始数据分布在叶子节点上,因此在实现时通常考虑递归建树。具体代码如下:

void build(int l, int r, int k) {
    //对区间[l,r]建树,当前节点为k
    if (l == r) {
        sum[k] = a[l];
        return;
    }//当左右端点相等时赋值并返回
    int mid = (l + r) / 2;//取区间中点
    build(l, mid, k * 2);//左子节点建树
    build(mid + 1, r, k * 2 + 1);//右子节点建树
    sum[k] = sum[k * 2] + sum[k * 2 + 1];//合并
    return;
}

值得注意的是,当我们在定义线段树空间大小时,最好定义为原数组的 4倍 大小。这里不阐述原因(单纯是懒)

完成建树后,我们要对该树进行区间查询操作。具体代码如下:

int query(int l, int r, int k, int x, int y) {
    //l,r表示当前区间的左右端点,k表示当前节点,x,y表示查询区间
    if (l >= x && r <= y) return sum[k];
    //当前区间为查询区间的子集时,返回该节点的值
    if (l > y || r < x) return 0;
    //当前区间与查询区间无交集时,返回0(不影响求和)
    int mid = (l + r) / 2, res = 0;
    res += query(l, mid, k * 2, x, y);
    res += query(mid + 1, r, k * 2 + 1, x, y);BUT
    //处理左右子区间求和
    return res;//返回当前区间答案
}

以上两个操作看似十分完美...

BUT

线段树在进行建树与查询的同时,也要求能够进行修改(一般是区间修改)

进行区间修改时为了节省时间,我们考虑使用一个 延迟标记 来记录当前节点的子节点所需要修改的量,并在查询修改操作中不断进行下传。具体代码如下:

void Add(int l, int r, int k, int v) {
    //对指定节点进行修改与标记,l,r为区间端点,k为节点,v为待加数
    add[k] += v;
    sum[k] += v * (r - l + 1);
    return;
}
void pushdown(int l, int r, int k) {
    //标记下传,l,r为区间端点,k为节点
    if (add[k] == 0) return;//若无标记,则不下传
    int mid = (l + r) / 2;
    Add(l, mid, k * 2, add[k]);//修改左子节点
    Add(mid + 1, r, k * 2 + 1, add[k]);//修改右子节点
    add[k] = 0;//标记归零
    return;
}
void modify(int l, int r, int k, int x, int y, int v) {
    //区间修改
    //l,r为区间端点,k为节点,x,y为待修改区间,v为加数
    if (l >= x && r <= y) {
        Add(l, r, k, v);
        return;
    }//当前区间为查询区间的子集时,进行标记修改
    if (l > y || r < x) return;
    pushdown(l, r, k);//标记下传
    int mid = (l + r) / 2;
    modify(l, mid, k * 2, x, y, v);
    modify(mid + 1, r, k * 2 + 1, x, y, v);
    //修改左右子区间
    sum[k] = sum[k * 2] + sum[k * 2 + 1];
    return;
}

与此同时,区间查询操作也需要进行修改:

int query(int l, int r, int k, int x, int y) {
    if (l >= x && r <= y) return sum[k];
    if (l > y || r < x) return 0;
    pushdown(l, r, k);//进行标记下传
    int mid = (l + r) / 2, res1, res2;
    res1 = query(l, mid, k * 2, x, y);
    res2 = query(mid + 1, r, k * 2 + 1, x, y);
    return res1 + res2;
}

这里由于,我们查询时访问到的节点可能存在没有被下传的标记导致结果错误。因此在查询的时候也要记得下传标记。

到这里我们的线段树就完美的结束了

完整代码如下:

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5 + 5;
int a[N];
int sum[4 * N], add[4 * N];
int n, m;
void build(int l, int r, int k) {
    if (l == r) {
        sum[k] = a[l];
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, k * 2);
    build(mid + 1, r, k * 2 + 1);
    sum[k] = sum[k * 2] + sum[k * 2 + 1];
    return;
}
void Add(int l, int r, int k, int v) {
    add[k] += v;
    sum[k] += v * (r - l + 1);
    return;
}
void pushdown(int l, int r, int k) {
    if (add[k] == 0) return;
    int mid = (l + r) / 2;
    Add(l, mid, k * 2, add[k]);
    Add(mid + 1, r, k * 2 + 1, add[k]);
    add[k] = 0;
    return;
}
void modify(int l, int r, int k, int x, int y, int v) {
    if (l >= x && r <= y) {
        Add(l, r, k, v);
        return;
    }
    if (l > y || r < x) return;
    pushdown(l, r, k);
    int mid = (l + r) / 2;
    modify(l, mid, k * 2, x, y, v);
    modify(mid + 1, r, k * 2 + 1, x, y, v);
    sum[k] = sum[k * 2] + sum[k * 2 + 1];
    return;
}
int query(int l, int r, int k, int x, int y) {
    if (l >= x && r <= y) return sum[k];
    if (l > y || r < x) return 0;
    pushdown(l, r, k);
    int mid = (l + r) / 2, res1, res2;
    res1 = query(l, mid, k * 2, x, y);
    res2 = query(mid + 1, r, k * 2 + 1, x, y);
    return res1 + res2;
}
signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    build(1, n, 1);
    while (m--) {
        int op, x, y, k;
        cin >> op >> x >> y;
        if (op == 1) {
            cin >> k;
            modify(1, n, 1, x, y, k);
        } else {
            cout << query(1, n, 1, x, y) << "\n";
        }
    }
    return 0;
}

THE END