【模板】线段树

个人记录

线段树一（区间加法、区间求和）

#include <iostream>
#include <cstdio>
#define ll long long
using namespace std;
// Largest element count this program's storage is sized for.
constexpr int maxn = 100010;

// n: number of elements, m: number of operations (both read in main).
ll n = 0;
ll m = 0;

// One segment-tree node.
//   val: sum of the elements in the node's range (pending tags of this node
//        already applied).
//   add: per-element increment still owed to the node's children.
struct segtree {
    ll add;
    ll val;
};

// Heap-style storage: root at index 1, children of p at 2p and 2p+1.
segtree st[4 * maxn + 10];
inline ll read() {
    ll x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c=getchar();}
    return f * x;
}
// Recomputes node p's range sum from its left (2p) and right (2p+1) children.
inline void up(ll p) {
    ll lc = 2 * p;
    ll rc = lc + 1;
    st[p].val = st[lc].val + st[rc].val;
}
// Builds the subtree rooted at node p that covers positions [l, r].
// Every node's lazy tag is cleared. Leaf values are read from stdin, and the
// left subtree is always built first, so leaves are read in position order.
inline void build(ll p, ll l, ll r) {
    st[p].add = 0;
    if(l != r) {
        ll mid = (l + r) / 2;
        build(2 * p, l, mid);          // left half first: preserves input order
        build(2 * p + 1, mid + 1, r);
        up(p);
    } else {
        st[p].val = read();
    }
}
// Pushes node p's pending increment (if any) into its two children, then
// clears p's tag. For each child this adds the increment to the child's tag
// and adds increment * child-length to the child's sum. [l, r] is p's range.
inline void down(ll p, ll l, ll r) {
    const ll tag = st[p].add;
    if(tag == 0) return;
    const ll mid = (l + r) / 2;
    segtree &lc = st[2 * p];
    segtree &rc = st[2 * p + 1];
    lc.add += tag;
    lc.val += (mid - l + 1) * tag;
    rc.add += tag;
    rc.val += (r - mid) * tag;
    st[p].add = 0;
}
// Adds v to every position in [x, y], starting at node p that covers [l, r].
// A fully covered node absorbs the change into its sum and lazy tag and stops.
// Otherwise the node pushes its tag down, recurses into the children that
// overlap [x, y], and recomputes its sum.
inline void update(ll p, ll l, ll r, ll x, ll y, ll v) {
    bool covered = l >= x && r <= y;
    if(covered) {
        st[p].val += v * (r - l + 1);
        st[p].add += v;
        return;
    }
    down(p, l, r);
    ll mid = (l + r) / 2;
    if(mid >= x) update(2 * p, l, mid, x, y, v);
    if(mid < y) update(2 * p + 1, mid + 1, r, x, y, v);
    up(p);
}
// Returns the sum of positions [x, y], starting at node p that covers [l, r].
// Partially covered nodes push their pending tags down before recursing.
inline ll query(ll p, ll l, ll r, ll x, ll y){
    if(l >= x && r <= y) return st[p].val;
    down(p, l, r);
    const ll mid = (l + r) / 2;
    ll left = 0, right = 0;
    if(mid >= x) left = query(2 * p, l, mid, x, y);
    if(mid < y) right = query(2 * p + 1, mid + 1, r, x, y);
    return left + right;
}
// Reads n, m and the n initial values (consumed by build), then processes m
// operations from stdin:
//   "1 a b x"  -> add x to every element in [a, b]
//   "op a b"   -> (any op other than 1) print the sum of [a, b]
int main() {
    n = read(), m = read();
    build(1, 1, n);
    for(ll i = 1; i <= m; i++) {
        ll tmp, a, b, x;
        tmp = read();
        if(tmp == 1) {
            a = read(), b = read(), x = read();
            update(1, 1, n, a, b, x);
        } else {
            a = read(), b = read();
            // printf + '\n' instead of cout << endl: endl flushes stdout on
            // every query, which is needlessly slow across many queries.
            printf("%lld\n", query(1, 1, n, a, b));
        }
    }
    return 0;
}

线段树二（区间乘法、区间加法、区间求和，结果取模）

#include <iostream>
#include <cstdio>
#define ll long long 
using namespace std;
// Largest element count this program's storage is sized for.
constexpr int maxn = 100010;

// mod: modulus for all arithmetic; n: element count; m: operation count.
ll mod = 0;
ll n = 0;
ll m = 0;

// a[1..n]: initial values, read in main before build.
ll a[maxn];

// One segment-tree node.
//   v:        sum of the node's range, reduced modulo mod.
//   mul, add: pending tag owed to the children; every element e below this
//             node should become e * mul + add (mod mod).
struct node {
    ll v;
    ll mul;
    ll add;
};

// Heap-style storage: root at index 1, children of p at 2p and 2p+1.
node st[4 * maxn + 10];
// Reads the next integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. Reading stops at the first non-digit
// after the number (that character is consumed).
// Returns 0 if end of input is reached before any digit is found.
inline long long read() {
    long long x = 0, f = 1;
    // int, not char: getchar() returns EOF (a negative int) at end of input,
    // which a char cannot represent reliably. With a char the skip loop below
    // never terminates on truncated input.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return f * x;
}
inline void build(int p, int l, int r) {
    st[p].mul = 1;
    st[p].add = 0;
    if(l == r){
        st[p].v = a[l];
    } else {
        int m = (l + r) / 2;
        build(2*p, l, m);
        build(2*p+1, m+1, r);
        st[p].v = st[2*p].v + st[2*p+1].v;
    }
    st[p].v %= mod;
    return ;
}
// Applies node p's pending tag (multiply by mul, then add add) to both
// children of p, whose range is [l, r].
// For each child: scales its sum and adds add * child-length, then composes
// p's tag after the child's own tag. Finally p's tag is reset to the identity.
inline void down(ll p, ll l, ll r) {
    const ll mid = (l + r) / 2;
    const ll mul = st[p].mul;
    const ll add = st[p].add;
    const ll len[2] = {mid - l + 1, r - mid};  // lengths of left / right child
    for(int k = 0; k < 2; k++) {
        node &c = st[2 * p + k];
        c.v = (c.v * mul + add * len[k]) % mod;
        c.mul = (c.mul * mul) % mod;
        c.add = (c.add * mul + add) % mod;
    }
    st[p].mul = 1;
    st[p].add = 0;
}
// Multiplies every element in [x, y] by v (mod mod), starting at node p that
// covers [l, r].
// A fully covered node scales its sum and both lazy components and stops.
// A partially covered node pushes its tag down, recurses into both children
// (disjoint ones return immediately), and recomputes its sum.
inline void update1(ll p, ll l, ll r, ll x, ll y, ll v) {
    if(y < l || x > r){
        return ;
    }
    // Reduce v into [0, mod) first. A large v would overflow st[p].v * v, and a
    // negative v would produce negative residues (C++ % keeps the sign).
    v = (v % mod + mod) % mod;
    if(x <= l && y >= r){
        st[p].v = (st[p].v * v) % mod;
        st[p].mul = (st[p].mul * v) % mod;
        st[p].add = (st[p].add * v) % mod;
        return ;
    }
    down(p, l, r);
    ll mid = (l + r) / 2;
    update1(p*2, l, mid, x, y, v);
    update1(p*2+1, mid+1, r, x, y, v);
    st[p].v = (st[p*2].v + st[p*2+1].v) % mod;
}
// Adds v to every element in [x, y] (mod mod), starting at node p that
// covers [l, r].
// A fully covered node adds v to its add tag and v * length to its sum, then
// stops. A partially covered node pushes its tag down, recurses into both
// children (disjoint ones return immediately), and recomputes its sum.
inline void update2(ll p, ll l, ll r, ll x, ll y, ll v) {
    if(y < l || x > r){
        return ;
    }
    // Reduce v into [0, mod) first. A large v could overflow v * (r-l+1), and a
    // negative v would produce negative residues (C++ % keeps the sign).
    v = (v % mod + mod) % mod;
    if(x <= l && y >= r){
        st[p].add = (st[p].add + v) % mod;
        st[p].v = (st[p].v + v * (r-l+1)) % mod;
        return ;
    }
    down(p, l, r);
    ll mid = (l + r) / 2;
    update2(p*2, l, mid, x, y, v);
    update2(p*2+1, mid+1, r, x, y, v);
    st[p].v = (st[p*2].v + st[p*2+1].v) % mod;
}
// Returns the sum of elements in [x, y] modulo mod, starting at node p that
// covers [l, r].
// Disjoint nodes contribute 0; fully covered nodes return their stored sum;
// partially covered nodes push their tag down and combine both children.
inline ll query(ll p, ll l, ll r, ll x, ll y) {
    const bool disjoint = r < x || l > y;
    if(disjoint) return 0;
    const bool inside = l >= x && r <= y;
    if(inside) return st[p].v;
    down(p, l, r);
    const ll mid = (l + r) / 2;
    const ll left = query(2 * p, l, mid, x, y);
    const ll right = query(2 * p + 1, mid + 1, r, x, y);
    return (left + right) % mod;
}

int main(){
    n = read(), m = read(), mod = read();
    for(int i=1; i<=n; i++){
        a[i] = read();
    }
    build(1, 1, n);
    while(m--){
        ll tmp, a, b, x;
        tmp = read();
        if(tmp == 1){
            a = read(), b = read(), x = read();
            update1(1, 1, n, a, b, x);
        }
        else if(tmp == 2){
            a = read(), b = read(), x = read();
            update2(1, 1, n, a, b, x);
        } else {
            a = read(), b = read();
            printf("%lld\n", query(1, 1, n, a, b));
        }
    }
    return 0;
}