P3373 【模板】线段树 2

· · 题解

话说学线段树这么久现在才做这个题 我好菜啊

作为一个蒟蒻还是决定码一下这个线段树2

线段树plus在标准线段树的基础上额外增加了区间乘法

也就是说需要写出如下的一个数据结构

很明显我们需要线段树,这里的lazytag需要搞两个。

乘法lazytag:

本着对两个标记一视同仁的态度,我们要给加和乘都要整一个lazytag

struct P{
    int add,mul;//乘法标记和加法标记
}tag[N*4];

需要注意的是:乘法标记需要初始化为1

一、标记传递

和一般线段树一样,我们需要有pushdownpushup两个传递函数

其中,pushup是对区间和进行上传,和一般线段树没有区别

void pushup(int x){
    sum[x]=(sum[x<<1]+sum[x<<1|1])%p;
}

pushdown则不同,因为有两个标记我们需要对两个标记进行优先级考虑

乘法标记应当优先下传,理由是先乘法后加法不会影响结果的精度。

  1. 对于sum区间和,我们先对儿子节点区间和×乘法标记+加法标记 ×区间长度
  2. 对于乘法标记,我们对儿子节点直接×乘法标记
  3. 对于加法标记,我们对儿子节点先×乘法标记+加法标记
    void pushdown(int x,int l,int r){
    int ls=x<<1,rs=x<<1|1,mid=(l+r)>>1;
    //左儿子
    sum[ls]=(sum[ls]*tag[x].mul%p+tag[x].add*(mid-l+1)%p)%p;
    tag[ls].mul=(tag[ls].mul*tag[x].mul)%p;
    tag[ls].add=(tag[ls].add*tag[x].mul+tag[x].add)%p;
    //右儿子
    sum[rs]=(sum[rs]*tag[x].mul%p+tag[x].add*(r-(mid+1)+1)%p)%p;
    tag[rs].mul=(tag[rs].mul*tag[x].mul)%p;
    tag[rs].add=(tag[rs].add*tag[x].mul+tag[x].add)%p;
    //标记初始化
    tag[x].add=0;tag[x].mul=1;
    return;
    }

二、加法区间修改

然后我们的区间修改updata也需要写成两个函数

对于加法updata我们要

  1. 对加法标记加上k
  2. 区间和加上 k×ll是区间长度)

其余的写法和标准线段树一模一样

void updata_add(int x,int l,int r,int ll,int rr,int k){
    if(r<ll || l>rr) return;//当前区间在目标区间之外,直接返回
    if(l>=ll && r<=rr){//当前区间被目标区间包含,标记修改
        sum[x]=(sum[x]+k*(r-l+1)%p)%p;
        tag[x].add=(tag[x].add+k)%p;
        return;
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;//对左右儿子进行修改
    updata_add(x<<1,l,mid,ll,rr,k);
    updata_add(x<<1|1,mid+1,r,ll,rr,k);
    pushup(x);
}

三、乘法区间修改

乘法updata其余写法和加法一样,唯一不同是当前区间被查询区间覆盖的时候,我们需要对两个标记都进行修改,没有疑问,全部乘以k

  1. 对乘法标记乘以k
  2. 对加法标记乘以k
  3. 对区间和乘以k
void updata_mul(int x,int l,int r,int ll,int rr,int k){
    if(r<ll || l>rr) return;
    if(l>=ll && r<=rr){
        sum[x]=(sum[x]*k)%p;
        tag[x].mul=(tag[x].mul*k)%p;
        tag[x].add=(tag[x].add*k)%p;
        return;
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    updata_mul(x<<1,l,mid,ll,rr,k);
    updata_mul(x<<1|1,mid+1,r,ll,rr,k);
    pushup(x);
}

四、区间和查询

和标准线段树一模一样,查询的时候注意下传标记。

int query(int x,int l,int r,int ll,int rr){
    if(r<ll || l>rr) return 0;//当前区间在目标区间外,不影响区间和
    if(l>=ll && r<=rr){//当前区间被目标期间包含,直接返回区间和
        return sum[x];
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    return (query(x<<1,l,mid,ll,rr)+query(x<<1|1,mid+1,r,ll,rr))%p;
}

p.s. 还有一些小细节,比如对于p的取模,记得要开longlong,空间开四倍等等... ...

(十年OI一场空,不开longlong见祖宗)

完整代码如下:

#include<bits/stdc++.h>
#define N 100100
#define int long long//第一次写没开longlong的本蒟蒻
using namespace std;
int n,m,p;
int a[N*4];
int sum[N*4];
struct P{
    int add,mul;
}tag[N*4];
void pushup(int x){
    sum[x]=(sum[x<<1]+sum[x<<1|1])%p;
}
void build(int x,int l,int r){//建树
    if(l==r){
        sum[x]=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(x<<1,l,mid);
    build(x<<1|1,mid+1,r);
    pushup(x);
}
void pushdown(int x,int l,int r){
    int ls=x<<1,rs=x<<1|1,mid=(l+r)>>1;
    sum[ls]=(sum[ls]*tag[x].mul%p+tag[x].add*(mid-l+1)%p)%p;
    tag[ls].mul=(tag[ls].mul*tag[x].mul)%p;
    tag[ls].add=(tag[ls].add*tag[x].mul+tag[x].add)%p;

    sum[rs]=(sum[rs]*tag[x].mul%p+tag[x].add*(r-(mid+1)+1)%p)%p;
    tag[rs].mul=(tag[rs].mul*tag[x].mul)%p;
    tag[rs].add=(tag[rs].add*tag[x].mul+tag[x].add)%p;

    tag[x].add=0;tag[x].mul=1;
    return;
}
void updata_mul(int x,int l,int r,int ll,int rr,int k){
    if(r<ll || l>rr) return;
    if(l>=ll && r<=rr){
        sum[x]=(sum[x]*k)%p;
        tag[x].mul=(tag[x].mul*k)%p;
        tag[x].add=(tag[x].add*k)%p;
        return;
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    updata_mul(x<<1,l,mid,ll,rr,k);
    updata_mul(x<<1|1,mid+1,r,ll,rr,k);
    pushup(x);
}
void updata_add(int x,int l,int r,int ll,int rr,int k){
    if(r<ll || l>rr) return;
    if(l>=ll && r<=rr){
        sum[x]=(sum[x]+k*(r-l+1)%p)%p;
        tag[x].add=(tag[x].add+k)%p;
        return;
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    updata_add(x<<1,l,mid,ll,rr,k);
    updata_add(x<<1|1,mid+1,r,ll,rr,k);
    pushup(x);
}
int query(int x,int l,int r,int ll,int rr){
    if(r<ll || l>rr) return 0;
    if(l>=ll && r<=rr){
        return sum[x];
    }
    pushdown(x,l,r);
    int mid=(l+r)>>1;
    return (query(x<<1,l,mid,ll,rr)+query(x<<1|1,mid+1,r,ll,rr))%p;
}
signed main(){
    scanf("%lld %lld %lld",&n,&m,&p);
    for(int i=1;i<=4*n;i++) tag[i].mul=1;//注意乘法标记初始化为1
    for(int i=1;i<=n;i++){
        scanf("%lld",&a[i]);
    }
    build(1,1,n);
    int opt,x,y,k;
    while(m--){
        scanf("%lld",&opt);
        if(opt==1){
            scanf("%lld %lld %lld",&x,&y,&k);
            updata_mul(1,1,n,x,y,k);
        }
        if(opt==2){
            scanf("%lld %lld %lld",&x,&y,&k);
            updata_add(1,1,n,x,y,k);
        }
        if(opt==3){
            scanf("%lld %lld",&x,&y);
            cout<<query(1,1,n,x,y)%p<<endl;
        }

    }
    return 0;
}