P3372 【模板】线段树 1? 讲解

· · 题解

题目
(用树状数组写线段树)

树状数组 1 是单点修改区间查询,树状数组 2 是区间修改单点查询,那么怎么实现区间修改区间查询呢?
既然要区间修改,那么我们依然要用差分(设原数组为 a,差分数组为 p):

\begin{aligned}\sum_{i=1}^na_i&=\sum_{i=1}^n\sum_{j=1}^ip_j\\&=\sum_{i=1}^n(n-i+1)p_i\end{aligned}

但由于 n 是不固定的,我们无法快速得到 \sum_{i=1}^n(n-i+1)p_i
正难则反。

\begin{aligned}\sum_{i=1}^n(n-i+1)p_i&=\sum_{i=1}^nnp_i-(i-1)p_i\\&=(\sum_{i=1}^nnp_i)-(\sum_{i=1}^n(i-1)p_i)\\&=n(\sum_{i=1}^np_i)-(\sum_{i=1}^n(i-1)p_i)\end{aligned}

那么,我们需要维护两棵树状数组,一颗存储 \sum_{i=1}^np_i(我们称为 b),一颗存储 \sum_{i=1}^n(i-1)p_i(我们称为 c),那么在每次修改的时候:

b.add(l,k);
b.add(r+1,-k);
c.add(l,(l-1)*k);
c.add(r+1,-r*k);

每次查询的结果就是:

\begin{aligned}\sum_{i=l}^{r}a_i&=(\sum_{i=1}^ra_i)-(\sum_{i=1}^{l-1}a_i)\\&=(r(\sum_{i=1}^rp_i)-(\sum_{i=1}^r(i-1)p_i))-((l-1)(\sum_{i=1}^{l-1}p_i)-(\sum_{i=1}^{l-1}(i-1)p_i))\end{aligned}
int sumr=r*b.getsum(r)-c.getsum(r),suml=(l-1)*b.getsum(l-1)-c.getsum(l-1);
write(sumr-suml);

总体复杂度 O(n\log n)
代码:

#include<bits/stdc++.h>
#define int long long
#define inl inline
#define INF 214748364721474836
#define rep(i,x,y) for(int i=x;i<=y;++i) 
using namespace std;
inl int read(){
    int f=1, x=0;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return f*x;
}
inl void write(int x){
    if(x<0){
        putchar('-');
        x=-x;
    }
    if(x>=10) write(x/10);
    putchar(x%10^48);
    return;
}
const int N=5e5+5;
int n,m,a[N],opt,l,r,k;
struct BIT{
    int p[N];
    inl int lowbit(int x){
        return x&(-x);
    }
    inl void add(int x,int k){
        while(x<=n){
            p[x]+=k;
            x+=lowbit(x);
        }
        return;
    }
    inl int getsum(int x){
        int sum=0;
        while(x){
            sum+=p[x];
            x-=lowbit(x);
        }
        return sum;
    }
}b,c;
signed main(){
    n=read();
    m=read();
    rep(i,1,n){
        a[i]=read();
        b.add(i,a[i]-a[i-1]);
        c.add(i,(i-1)*(a[i]-a[i-1]));
    }
    rep(i,1,m){
        opt=read();
        if(opt==1){
            l=read();
            r=read();
            k=read();
            b.add(l,k);
            b.add(r+1,-k);
            c.add(l,(l-1)*k);
            c.add(r+1,-r*k);
        }else{
            l=read();
            r=read();
            int sumr=r*b.getsum(r)-c.getsum(r),suml=(l-1)*b.getsum(l-1)-c.getsum(l-1);
            write(sumr-suml);
            puts("");
        }
    }
    return 0;
}