线段树

· · 算法·理论

一、查询区间和,维护区间加

#include<bits/stdc++.h>
#define N 100005
#define int long long
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
using namespace std;
struct node
{
    int l,r;
    int sum,add;
}t[N*4];
int n,m;
int a[N];
void update(int id)
{
    t[id].sum=t[ls(id)].sum+t[rs(id)].sum;
}
void build(int id,int l,int r)
{
    t[id].l=l,t[id].r=r;
    if(l==r)
    {
        t[id].sum=a[l];
        return;
    }
    int mid=l+r>>1;
    build(ls(id),l,mid);
    build(rs(id),mid+1,r);
    update(id); 
}
void pushdown(int id)
{
    if(t[id].add)
    {
        t[ls(id)].sum+=t[id].add*(t[ls(id)].r-t[ls(id)].l+1);
        t[rs(id)].sum+=t[id].add*(t[rs(id)].r-t[rs(id)].l+1);
        t[ls(id)].add+=t[id].add;
        t[rs(id)].add+=t[id].add;
        t[id].add=0;
    }
}
void change(int id,int l,int r,int k)
{
    if(l<=t[id].l&&t[id].r<=r)
    {
        t[id].sum+=k*(t[id].r-t[id].l+1);
        t[id].add+=k;
        return; 
    }
    pushdown(id);
    int mid=t[id].l+t[id].r>>1;
    if(l<=mid) change(ls(id),l,r,k);
    if(r>mid) change(rs(id),l,r,k);
    update(id);
}
int query(int id,int l,int r)
{
    if(l<=t[id].l&&t[id].r<=r) return t[id].sum;
    pushdown(id);
    int mid=t[id].l+t[id].r>>1;
    int ans=0;
    if(l<=mid) ans+=query(ls(id),l,r);
    if(r>mid) ans+=query(rs(id),l,r);
    return ans;
}
signed main()
{
    cin>>n>>m;
    for(int i=1;i<=n;++i) cin>>a[i];
    build(1,1,n);
    for(int i=1;i<=m;++i)
    {
        int opt,x,y,k;
        cin>>opt;
        if(opt==1)
        {
            cin>>x>>y>>k;
            change(1,x,y,k);
        }
        else if(opt==2)
        {
            cin>>x>>y;
            cout<<query(1,x,y)<<"\n";
        }
    }
    return 0;
}

二、查询区间和,维护区间加、区间乘

#include<bits/stdc++.h>
#define int long long
#define N 100005
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
using namespace std;
struct node
{
    int l,r;
    int sum,add,mul;
}t[N*4];
int n,q,mod;
int a[N];
void update(int id)
{
    t[id].sum=(t[ls(id)].sum+t[rs(id)].sum)%mod;
}
void build(int id,int l,int r)
{
    t[id].l=l,t[id].r=r,t[id].mul=1;
    if(l==r){t[id].sum=a[l]%mod;return;}
    int mid=l+r>>1;
    build(ls(id),l,mid);
    build(rs(id),mid+1,r);
    update(id);
}
void pushdown(int id)
{
    t[ls(id)].sum=(t[ls(id)].sum*t[id].mul+t[id].add*(t[ls(id)].r-t[ls(id)].l+1))%mod;
    t[rs(id)].sum=(t[rs(id)].sum*t[id].mul+t[id].add*(t[rs(id)].r-t[rs(id)].l+1))%mod;
    t[ls(id)].mul=(t[ls(id)].mul*t[id].mul)%mod;
    t[rs(id)].mul=(t[rs(id)].mul*t[id].mul)%mod;
    t[ls(id)].add=(t[ls(id)].add*t[id].mul+t[id].add)%mod;
    t[rs(id)].add=(t[rs(id)].add*t[id].mul+t[id].add)%mod;
    t[id].add=0,t[id].mul=1;
}
void changeadd(int id,int l,int r,int k)
{
    if(l<=t[id].l&&t[id].r<=r)
    {
        t[id].add=(t[id].add+k)%mod;
        t[id].sum=(t[id].sum+k*(t[id].r-t[id].l+1))%mod;
        return;
    }
    pushdown(id);
    int mid=t[id].l+t[id].r>>1;
    if(l<=mid) changeadd(ls(id),l,r,k);
    if(r>mid) changeadd(rs(id),l,r,k);
    update(id);
}
void changemul(int id,int l,int r,int k)
{
    if(l<=t[id].l&&t[id].r<=r)
    {
        t[id].sum=(t[id].sum*k)%mod;
        t[id].add=(t[id].add*k)%mod;
        t[id].mul=(t[id].mul*k)%mod;
        return;
    }
    pushdown(id);
    int mid=t[id].l+t[id].r>>1;
    if(l<=mid) changemul(ls(id),l,r,k);
    if(r>mid) changemul(rs(id),l,r,k);
    update(id);
}
int query(int id,int l,int r)
{
    if(l<=t[id].l&&t[id].r<=r) return t[id].sum;
    pushdown(id);
    int ans=0;
    int mid=t[id].l+t[id].r>>1;
    if(l<=mid) ans=(ans+query(ls(id),l,r))%mod;
    if(r>mid) ans=(ans+query(rs(id),l,r))%mod;
    return ans;
}
signed main()
{
    cin>>n>>q>>mod;
    for(int i=1;i<=n;++i) cin>>a[i];
    build(1,1,n);
    while(q--)
    {
        int opt,x,y,k;
        cin>>opt;
        if(opt==1)
        {
            cin>>x>>y>>k;
            changemul(1,x,y,k);
        }
        else if(opt==2)
        {
            cin>>x>>y>>k;
            changeadd(1,x,y,k);
        }
        else if(opt==3)
        {
            cin>>x>>y;
            cout<<query(1,x,y)<<"\n";
        }
    }
    return 0;
}