题解 P3368 【【模板】树状数组 2】
看到题解里,没有用线段树的(我线段树这么不受人待见吗?!)
于是,我来写一发线段树的题解,很简单就是在线段树1那个板子的基础上修改了一点点
同样是用一棵线段树来维护整个数列,我们只需要做到区间修改和单点查询两个操作。 区间修改?怎么改呢?
inline void addval(int rt){
int l=t[rt].left,r=t[rt].right;
if(a<=l&&r<=b){//判断我们当前是否应该进行修改
t[rt].tag+=x;//先打上懒标记
t[rt].data+=(t[rt].right-t[rt].left+1)*x;
//由于线段树的节点是管理一个区间
//所以我们在修改节点值的时候,应该用区间去乘需要加入的值x
return ;
}
if(t[rt].tag) pushdown(rt);//有标记要先下传标记
if(a<=mid) addval(ls);
//如果要查询的区间有部分在左子树,那么向左子树修改
if(b>mid) addval(rs);
//同上,向右子树修改
update(rt);//同pushup
return ;
}
好了,解决了区间修改,我们要考虑单点查询了! 线段树是以区间查询著称,那么怎么才能做到单点修改呢?
inline void query(int rt){
a=pos,b=pos;//这里是我修改的地方
//我们用线段树修改时,是修改一个区间,它的左右端点分别是a,b
//单点查询,我们把它看作是一个长度为1的区间,左右端点相同
//所以我们把要查询的位置都赋值给a和b,这样就达到了单点查询的目的
int l=t[rt].left,r=t[rt].right;
if(a<=l&&r<=b){//同上一段代码
ans+=t[rt].data;
return ;
}
if(t[rt].tag) pushdown(rt);
if(a<=mid) query(ls);
if(b>mid) query(rs);
return ;
}
接下来附上完整代码:
#include<iostream>
#include<cstdio>
#define ll long long
#define ls (rt<<1)
#define rs (rt<<1|1)
#define mid ((l+r)>>1)
#define update(rt) t[rt].data=t[ls].data+t[rs].data
#define N 500005
//宏定义大法好
using namespace std;
struct tree{
ll tag,data;
int left,right;
}t[(N<<4)+5];//之前因为数组开小了,RE了3个点
int n,m,pos;
ll s[N],ans;
int a,b,x,k;
inline void build(int rt,int l,int r){
//递归建树
t[rt].left=l;t[rt].right=r;t[rt].tag=0;
if(l==r){
t[rt].data=s[l];
return ;
}
build(ls,l,mid);
build(rs,mid+1,r);
update(rt);
return ;
}
inline void pushdown(int rt){//标记下传
t[ls].tag+=t[rt].tag;
t[rs].tag+=t[rt].tag;
t[ls].data+=t[rt].tag*(t[ls].right-t[ls].left+1);
t[rs].data+=t[rt].tag*(t[rs].right-t[rs].left+1);
t[rt].tag=0;
return ;
}
inline void addval(int rt){
//区间修改
int l=t[rt].left,r=t[rt].right;
if(a<=l&&r<=b){
t[rt].tag+=x;
t[rt].data+=(t[rt].right-t[rt].left+1)*x;
return ;
}
if(t[rt].tag) pushdown(rt);
if(a<=mid) addval(ls);
if(b>mid) addval(rs);
update(rt);
return ;
}
inline void query(int rt){
//单点查询
a=pos,b=pos;
int l=t[rt].left,r=t[rt].right;
if(a<=l&&r<=b){
ans+=t[rt].data;
return ;
}
if(t[rt].tag) pushdown(rt);
if(a<=mid) query(ls);
if(b>mid) query(rs);
return ;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%lld",&s[i]);
build(1,1,n);
for(int i=1;i<=m;++i){
scanf("%d",&k);
if(k==1){
scanf("%d%d%d",&a,&b,&x);
addval(1);//区间修改
}
else{
scanf("%d",&pos);//要查询的位置
query(1);
printf("%lld\n",ans);
ans=0;
}
}
return 0;
}