题解 P3373 【【模板】线段树 2】

· · 题解

这个题非常好想,只需要把一般的线段树的延迟记号增加一个乘法运算即可。

下面从一次函数复合的角度来看待此题。

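spread操作相当于是一次函数的复合:设子结点已有的标记表示 $f(x)=\mathrm{mul}_c\cdot x+\mathrm{add}_c$,父结点要下传的标记表示 $g(x)=\mathrm{mul}_p\cdot x+\mathrm{add}_p$,则下传后子结点的标记变为 $g(f(x))=(\mathrm{mul}_p\cdot\mathrm{mul}_c)\,x+(\mathrm{mul}_p\cdot\mathrm{add}_c+\mathrm{add}_p)$,同时子结点的区间和变为 $\mathrm{sum}_c\cdot\mathrm{mul}_p+\mathrm{add}_p\cdot\mathrm{len}_c$(以上各量均对 mod 取模)。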

几点注意:

  1. mul的初值是1,不是0.

  2. 要用位移运算加速

  3. 善用define可以让程序看起来很直观,甚至不用写注释了

  4. 运算不妨分步进行,看着舒服

#include<bits/stdc++.h>
// Counted loop: declares int `i` and runs it from 1 through `x` inclusive.
// `x` is re-evaluated on every iteration. The `register` specifier is not
// used: it was deprecated in C++11 and removed in C++17.
#define For(i, x) for(int i=1; i<=(x); i++)
// Field accessors for segment-tree node number `x` (an index into `tree`).
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
#define mul(x) tree[x].mul
// Midpoint and element count of the inclusive range [l(x), r(x)].
#define mid(x) ((l(x)+r(x))/2)
#define len(x) (r(x)-l(x)+1)
// Child node numbers 2x and 2x+1. The argument is parenthesized so that an
// expression such as `a|b` expands correctly.
#define lson(x) ((x)<<1)
#define rson(x) ((x)<<1|1)
using namespace std;
typedef long long ll;

/**
 * Reads the next decimal integer from stdin.
 *
 * Skips every character that is not a digit; if a '-' is seen among the
 * skipped characters the result is negated. Consumes the digit run plus the
 * single character that terminates it.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
inline ll read(){
    ll x=0, sign=1;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so the skip loop below cannot spin forever at EOF.
    int c=getchar();
    while(c!=EOF && (c<'0' || c>'9')){
        if (c=='-') sign=-1;
        c=getchar();
    }
    while(c>='0' && c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
    return x*sign;
}
// Capacity bound for the input array; the node array below holds 4*N entries.
const int N=102000;
// Input values; main() fills a[1..n] and build() copies them into the leaves.
ll a[N];
// One segment-tree node. Node 1 is the root (see build(1, 1, n) in main) and
// node x has children 2x and 2x+1 (see lson/rson).
struct Segment_Tree{
    // Inclusive index range covered by this node.
    int l, r;
    // sum: range sum, kept reduced by `mod` after every update.
    // add, mul: pending lazy tags not yet pushed to the children; spread()
    // applies them to a child as value -> value*mul + add.
    ll sum, add, mul;
}tree[N<<2];

// n: number of elements, m: number of operations, mod: modulus.
// All three are read from input at the start of main().
int n, m, mod;

void build(int p, int l, int r){
    l(p)=l, r(p)=r, mul(p)=1;
    if (l==r) {
        sum(p)=a[l]; 
        sum(p)%=mod;
        return;
    }
    int mid=(l+r)/2;
    build(lson(p), l, mid);
    build(rson(p), mid+1, r);
    sum(p)=sum(lson(p))+sum(rson(p));
    sum(p)%=mod;
}

void spread(int p){
    if (add(p)==0 && mul(p)==1) return;

    sum(lson(p))*=mul(p);
    sum(lson(p))+=add(p)*len(lson(p));
    sum(lson(p))%=mod;
    add(lson(p))=add(lson(p))*mul(p)+add(p);
    add(lson(p))%=mod;
    mul(lson(p))*=mul(p);
    mul(lson(p))%=mod;

    sum(rson(p))*=mul(p);
    sum(rson(p))+=add(p)*len(rson(p));
    sum(rson(p))%=mod;
    add(rson(p))=add(rson(p))*mul(p)+add(p);
    add(rson(p))%=mod;
    mul(rson(p))*=mul(p);
    mul(rson(p))%=mod;

    add(p)=0;
    mul(p)=1;
}

void change(int p, int l, int r, ll d, int op){
    if (l<=l(p) && r(p)<=r){
        if (op==2){
            sum(p)+=(ll)d*len(p);
            add(p)+=d;
        } else{
            sum(p)*=d;
            add(p)*=d;
            mul(p)*=d;
        }
        sum(p)%=mod;
        add(p)%=mod;
        mul(p)%=mod;
        return;
    }
    spread(p);
    if (l<=mid(p))   change(lson(p), l, r, d, op);
    if (mid(p)+1<=r) change(rson(p), l, r, d, op);
    sum(p)=sum(lson(p))+sum(rson(p));
    sum(p)%=mod;
}

/**
 * Returns the sum of the elements in [l, r] within the subtree of node `p`,
 * reduced modulo `mod`.
 *
 * The result is normalized into [0, mod) on every return path. Previously
 * only the partial-overlap path did this, so a fully covered node whose
 * stored sum was negative was returned as a negative value.
 *
 * @param p  current node number (callers start at the root, 1)
 * @param l  left end of the query range (inclusive)
 * @param r  right end of the query range (inclusive)
 */
ll ask(int p, int l, int r){
    if (l<=l(p) && r(p)<=r) return (sum(p)%mod+mod)%mod;
    // Partial overlap: push pending tags so the children's sums are current.
    spread(p);
    const int split=mid(p);
    ll val=0;
    if (l<=split)   val+=ask(lson(p), l, r);
    if (split+1<=r) val+=ask(rson(p), l, r);
    // Each child result is already in [0, mod), so val cannot overflow here.
    return (val%mod+mod)%mod;
}

int main(){
    n=read(), m=read(), mod=read();
    For(i, n) a[i]=read();
    build(1, 1, n);
    For(i, m){
        int w=read(), x=read(), y=read();
        if (w!=3){
            ll z=read();
            change(1, x, y, z, w);
        } else 
            printf("%lld\n", ask(1, x, y));
    }
    return 0;
}