题解 P3373 【【模板】线段树 2】
EricWay1024 · · 题解
这个题非常好想,只需要把一般的线段树的延迟记号增加一个乘法运算即可。
下面从一次函数复合的角度来看待此题。
spread操作相当于是:
-
p的儿子节点原来的延迟记号是mx+n,表示把每个a[i]乘m后加n,记为f(x)=mx+n.
-
而今p的延迟记号是kx+b,记为g(x)=kx+b.
-
把p的延迟记号传递给它的儿子。那么对于p的儿子而言,相当于需要计算g(f(x)), 也就是:
-
g(f(x))=g(mx+n)=k(mx+n)+b=(km)x+(kn+b),表明p的儿子要把原来的mul乘p的mul,把自己的add乘p的mul后再加p的add。
几点注意:
-
mul的初值是1,不是0.
-
要用位移运算加速
-
善用define可以让程序看起来很直观,甚至不用写注释了
-
运算不妨分布进行,看着舒服
#include<bits/stdc++.h>
#define For(i, x) for(register int i=1; i<=x; i++)
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
#define mul(x) tree[x].mul
#define mid(x) ((l(x)+r(x))/2)
#define len(x) (r(x)-l(x)+1)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
using namespace std;
typedef long long ll;
inline ll read(){
ll x=0; ll sign=1; char c=getchar();
while(c>'9' || c<'0') {if (c=='-') sign=-1;c=getchar();}
while(c>='0' && c<='9'){x=(x<<3)+(x<<1)+c-'0';c=getchar();}
return x*sign;
}
const int N=102000;
ll a[N];
struct Segment_Tree{
int l, r;
ll sum, add, mul;
}tree[N<<2];
int n, m, mod;
void build(int p, int l, int r){
l(p)=l, r(p)=r, mul(p)=1;
if (l==r) {
sum(p)=a[l];
sum(p)%=mod;
return;
}
int mid=(l+r)/2;
build(lson(p), l, mid);
build(rson(p), mid+1, r);
sum(p)=sum(lson(p))+sum(rson(p));
sum(p)%=mod;
}
void spread(int p){
if (add(p)==0 && mul(p)==1) return;
sum(lson(p))*=mul(p);
sum(lson(p))+=add(p)*len(lson(p));
sum(lson(p))%=mod;
add(lson(p))=add(lson(p))*mul(p)+add(p);
add(lson(p))%=mod;
mul(lson(p))*=mul(p);
mul(lson(p))%=mod;
sum(rson(p))*=mul(p);
sum(rson(p))+=add(p)*len(rson(p));
sum(rson(p))%=mod;
add(rson(p))=add(rson(p))*mul(p)+add(p);
add(rson(p))%=mod;
mul(rson(p))*=mul(p);
mul(rson(p))%=mod;
add(p)=0;
mul(p)=1;
}
void change(int p, int l, int r, ll d, int op){
if (l<=l(p) && r(p)<=r){
if (op==2){
sum(p)+=(ll)d*len(p);
add(p)+=d;
} else{
sum(p)*=d;
add(p)*=d;
mul(p)*=d;
}
sum(p)%=mod;
add(p)%=mod;
mul(p)%=mod;
return;
}
spread(p);
if (l<=mid(p)) change(lson(p), l, r, d, op);
if (mid(p)+1<=r) change(rson(p), l, r, d, op);
sum(p)=sum(lson(p))+sum(rson(p));
sum(p)%=mod;
}
ll ask(int p, int l, int r){
if (l<=l(p) && r(p)<=r) return sum(p);
spread(p);
ll val=0;
if (l<=mid(p)) val+=ask(lson(p), l, r);
val%=mod;
if (mid(p)+1<=r) val+=ask(rson(p), l, r);
val%=mod; val+=mod; val%=mod;
return val;
}
int main(){
n=read(), m=read(), mod=read();
For(i, n) a[i]=read();
build(1, 1, n);
For(i, m){
int w=read(), x=read(), y=read();
if (w!=3){
ll z=read();
change(1, x, y, z, w);
} else
printf("%lld\n", ask(1, x, y));
}
return 0;
}