题解:P3373 【模板】线段树 2
xiaolongmei · · 题解
链接
温馨提示:因为作者非常懒,所以与线段树 1 相同的地方会没有任何注释!!!
思路
有了线段树 1 的基础,我们可以直接考虑在线段树 1 的代码中修改,使这个代码能实现
定义
struct inf{ll l,r,sum,la,lm;}t[4*N];/*la=lazy add lm=lazy multiplication(乘法)*/
主函数
变化很少
int main(){
cin>>n>>q>>mod;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(q--){
cin>>op>>x>>y;
if(op==1)cin>>k,update(1,k,0);/*多了一个乘法要乘的数,和加法要加的数*/
else if(op==2)cin>>k,update(1,1,k);/*同上*/
else cout<<query(1)<<"\n";
}
return 0;
}
建树 (build)
几乎没有什么变化
void build(ll p,ll l,ll r){
t[p]={l,r,a[l],0,1};/*乘法标记依次累乘所以用1*/
if(l==r)return;
ll mid=l+r>>1;
build(2*p,l,mid),build(2*p+1,mid+1,r);
pushup(p);
}
计算子节点的 sum
设父节点的懒标记为
代码
void pushdown(ll p){
js(2*p,t[p].lm,t[p].la),js(2*p+1,t[p].lm,t[p].la);/*更新左右节点*/
t[p].la=0,t[p].lm=1;/*乘法不写1就废了*/
}
下面提供更新部分
void js(ll p,ll m,ll a){/*要更新的节点编号和它的父节点的两个标记*/
t[p].sum=(t[p].sum*m+(t[p].r-t[p].l+1)*a)%mod;/*要乘的+要加的*/
t[p].lm=t[p].lm*m%mod;/*更新乘法标记*/
t[p].la=(t[p].la*m+a)%mod;/*更新加法标记*/
}
区间修改
改动不多,直接上代码!!!
void update(ll p,ll m,ll a){
if(t[p].l>y||t[p].r<x)return;/*完全没有交集直接退出*/
if(t[p].l>=x&&t[p].r<=y){js(p,m,a);return;}/*完全覆盖更新这个节点的值*/
pushdown(p);
update(2*p,m,a),update(2*p+1,m,a);
pushup(p);
}
区间查询
改动不多,直接上代码!!!
ll query(ll p){
if(y<t[p].l||x>t[p].r)return 0;/*完全不覆盖退出*/
if(t[p].l>=x&&t[p].r<=y)return t[p].sum;/*完全覆盖,返回*/
pushdown(p);
return (query(2*p)+query(2*p+1))%mod;/*两者和%mod*/
}
还是挺好理解的,送个完整代码吧!
完整代码
懒得再打一边注释了,不懂的上面看看吧!
#include<bits/stdc++.h>
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
#define ll long long
#define N 100005
using namespace std;
struct inf{ll l,r,sum,la,lm;}t[4*N];/*la=lazy add lm=lazy mul....*/
ll n,q,mod,a[N],op,x,y,k;
void pushup(ll p){t[p].sum=(t[2*p].sum+t[2*p+1].sum)%mod;}
void js(ll p,ll m,ll a){
t[p].sum=(t[p].sum*m+(t[p].r-t[p].l+1)*a)%mod;
t[p].lm=t[p].lm*m%mod;
t[p].la=(t[p].la*m+a)%mod;
}
void pushdown(ll p){
js(2*p,t[p].lm,t[p].la),js(2*p+1,t[p].lm,t[p].la);
t[p].la=0,t[p].lm=1;
}
void build(ll p,ll l,ll r){
t[p]={l,r,a[l],0,1};
if(l==r)return;
ll mid=l+r>>1;
build(2*p,l,mid),build(2*p+1,mid+1,r);
pushup(p);
}
void update(ll p,ll m,ll a){
if(t[p].l>y||t[p].r<x)return;
if(t[p].l>=x&&t[p].r<=y){js(p,m,a);return;}
pushdown(p);
update(2*p,m,a),update(2*p+1,m,a);
pushup(p);
}
ll query(ll p){
if(y<t[p].l||x>t[p].r)return 0;
if(t[p].l>=x&&t[p].r<=y)return t[p].sum;
pushdown(p);
return (query(2*p)+query(2*p+1))%mod;
}
int main(){
cin>>n>>q>>mod;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(q--){
cin>>op>>x>>y;
if(op==1)cin>>k,update(1,k,0);
else if(op==2)cin>>k,update(1,1,k);
else cout<<query(1)<<"\n";
}
return 0;
}