线段树学习笔记

· · 个人记录

线段树可以在 O(logn) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

建树

void build(int s, int t, int p) {
    // 对 [s,t] 区间建立线段树,当前根的编号为 p
    if (s == t) {
        d[p] = a[s];
        return;
    }
    int m = (s + t) / 2;
    build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
    // 递归对左右区间建树
    d[p] = d[p * 2] + d[(p * 2) + 1];
}

区间查询(区间求和)

int getsum(int l, int r, int s, int t, int p) {
    // [l,r] 为查询区间,[s,t] 为当前节点包含的区间,p为当前节点的编号
    if (l <= s && t <= r) return d[p];
    // 当前区间为询问区间的子集时直接返回当前区间的和
    int m = (s + t) / 2;
    if (b[p]) {
        // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
        d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
        b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
        b[p] = 0;                                    // 清空当前节点的标记
    }
    int sum = 0;
    if (l <= m) sum = getsum(l, r, s, m, p * 2);
    if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
    return sum;
}

区间修改(区间加上某个值)

void update(int l, int r, int c, int s, int t, int p) {
    // [l,r] 为修改区间,c 为被修改的元素的变化量,[s,t] 为当前节点包含的区间,p
    // 为当前节点的编号
    if (l <= s && t <= r) {
        d[p] += (t - s + 1) * c, b[p] += c;
        return;
    }  // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改
    int m = (s + t) / 2;
    if (b[p] && s != t) {
        // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
        d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
        b[p * 2] += b[p], b[p * 2 + 1] += b[p];  // 将标记下传给子节点
        b[p] = 0;                                // 清空当前节点的标记
    }
    if (l <= m) update(l, r, c, s, m, p * 2);
    if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
    d[p] = d[p * 2] + d[p * 2 + 1];
}
#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, a[100005], d[270005], b[270005];
int q;
void build(int l, int r, int p) {
    if (l == r) {
        d[p] = a[l];
        return;
    }
    int m = (l + r) >> 1;
    build(l, m, p << 1), build(m + 1, r, (p << 1) | 1);
    d[p] = d[p << 1] + d[(p << 1) | 1];
}
void update(int l, int r, int c, int s, int t, int p) {
    if (l <= s && t <= r) {
        d[p] += (t - s + 1) * c, b[p] += c;
        return;
    }
    int m = (s + t) >> 1;
    if (b[p]) {
        d[p << 1] += b[p] * (m - s + 1), d[(p << 1) | 1] += b[p] * (t - m);
        b[p << 1] += b[p], b[(p << 1) | 1] += b[p];
    }
    b[p] = 0;
    if (l <= m) update(l, r, c, s, m, p << 1);
    if (r > m) update(l, r, c, m + 1, t, (p << 1) | 1);
    d[p] = d[p << 1] + d[(p << 1) | 1];
}
int getsum(int l, int r, int s, int t, int p) {
    if (l <= s && t <= r) return d[p];
    int m = (s + t) >> 1;
    if (b[p]) {
        d[p << 1] += b[p] * (m - s + 1), d[(p << 1) | 1] += b[p] * (t - m);
        b[p << 1] += b[p], b[(p << 1) | 1] += b[p];
    }
    b[p] = 0;
    int sum = 0;
    if (l <= m) sum = getsum(l, r, s, m, p << 1);
    if (r > m) sum += getsum(l, r, m + 1, t, (p << 1) | 1);
    return sum;
}
signed main() {
    cin>>n>>q;
    for(int i=1; i<=n; i++) scanf("%lld", &a[i]);
    build(1, n, 1);
    while(q--) {
        int opt, x, y, k;
        scanf("%lld%lld%lld", &opt, &x, &y);
        if (opt == 2) printf("%lld\n", getsum(x, y, 1, n, 1));
        else {
            scanf("%lld", &k);
            update(x, y, k, 1, n, 1);
        }
    }
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, a[1000005], d[2700005], b[2700005];
int q;
void build(int l, int r, int p) {
    if (l == r) {
        d[p] = a[l];
        return;
    }
    int m = (l + r) >> 1;
    build(l, m, p << 1), build(m + 1, r, (p << 1) | 1);
    d[p] = d[p << 1] + d[(p << 1) | 1];
}
void update(int l, int r, int c, int s, int t, int p) {
    if (l <= s && t <= r) {
        d[p] += (t - s + 1) * c, b[p] += c;
        return;
    }
    int m = (s + t) >> 1;
    if (b[p]) {
        d[p << 1] += b[p] * (m - s + 1), d[(p << 1) | 1] += b[p] * (t - m);
        b[p << 1] += b[p], b[(p << 1) | 1] += b[p];
    }
    b[p] = 0;
    if (l <= m) update(l, r, c, s, m, p << 1);
    if (r > m) update(l, r, c, m + 1, t, (p << 1) | 1);
    d[p] = d[p << 1] + d[(p << 1) | 1];
}
int getsum(int l, int r, int s, int t, int p) {
    if (l <= s && t <= r) return d[p];
    int m = (s + t) >> 1;
    if (b[p]) {
        d[p << 1] += b[p] * (m - s + 1), d[(p << 1) | 1] += b[p] * (t - m);
        b[p << 1] += b[p], b[(p << 1) | 1] += b[p];
    }
    b[p] = 0;
    int sum = 0;
    if (l <= m) sum = getsum(l, r, s, m, p << 1);
    if (r > m) sum += getsum(l, r, m + 1, t, (p << 1) | 1);
    return sum;
}
signed main() {
    cin>>n>>q;
    for(int i=1; i<=n; i++) scanf("%lld", &a[i]);
    build(1, n, 1);
    while(q--) {
        int opt, x, k;
        scanf("%lld%lld%lld", &opt, &x, &k);
        if (opt == 2) printf("%lld\n", getsum(x, k, 1, n, 1));
        else update(x, x, k, 1, n, 1);
    }
    return 0;
}