线段树知识总结
适用于维护满足结合律的运算,如区间和、区间最大值、区间最大公约数。
难点(精髓)在于lazytag实现的不需要就不修改。
一、区间加
(P3372 【模板】线段树 1)
#include <bits/stdc++.h>
using namespace std;
const int MAX_N=100005;
struct node{
int l,r;
long long sum,add;
}tree[MAX_N<<2];
int a[MAX_N],n,m;
void build(int p,int l,int r){
tree[p].l=l;
tree[p].r=r;
if(l==r){
tree[p].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
tree[p].sum=tree[p*2].sum+tree[p*2+1].sum;
}
void spread(int p){
tree[p*2].sum+=tree[p].add*(tree[p*2].r-tree[p*2].l+1);
tree[p*2+1].sum+=tree[p].add*(tree[p*2+1].r-tree[p*2+1].l+1);
tree[p*2].add+=tree[p].add;
tree[p*2+1].add+=tree[p].add;
tree[p].add=0;
}
void change(int p,int x,int y,int k){
if(x<=tree[p].l && y>=tree[p].r){
tree[p].add+=k;
tree[p].sum+=k*(tree[p].r-tree[p].l+1);
return;
}
if(tree[p].add){
spread(p);
}
int mid=(tree[p].l+tree[p].r)>>1;
if(x<=mid){
change(p*2,x,y,k);
}
if(y>mid){
change(p*2+1,x,y,k);
}
tree[p].sum=tree[p*2].sum+tree[p*2+1].sum;
}
long long ask(int p,int x,int y){
if(x<=tree[p].l && y>=tree[p].r){
return tree[p].sum;
}
long long ans=0;
int mid=(tree[p].l+tree[p].r)>>1;
if(tree[p].add){
spread(p);
}
if(x<=mid){
ans+=ask(p*2,x,y);
}
if(y>mid){
ans+=ask(p*2+1,x,y);
}
return ans;
}
int main(){
scanf("%d %d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
build(1,1,n);
int opt,x,y,k;
for(int i=1;i<=m;++i){
scanf("%d",&opt);
if(opt==1){
scanf("%d %d %d",&x,&y,&k);
change(1,x,y,k);
}
else{
scanf("%d %d",&x,&y);
printf("%lld\n",ask(1,x,y));
}
}
return 0;
}
二、区间乘
注意先乘再加。(乘法分配律)
(P3373 【模板】线段树 2)
#include <bits/stdc++.h>
using namespace std;
const int MAX_N=100005;
struct node{
int l,r;
long long sum,mul=1,add;
}t[MAX_N<<2];
int mod,n,m,a[MAX_N];
int ls(int x){
return x*2;
}
int rs(int x){
return x*2+1;
}
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
if(l==r){
t[p].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(ls(p),l,mid);
build(rs(p),mid+1,r);
t[p].sum=(t[ls(p)].sum+t[rs(p)].sum)%mod;
}
void spread(int p){
t[ls(p)].sum=((t[ls(p)].sum*t[p].mul)%mod+t[p].add*(t[ls(p)].r-t[ls(p)].l+1)%mod)%mod;
t[rs(p)].sum=((t[rs(p)].sum*t[p].mul)%mod+t[p].add*(t[rs(p)].r-t[rs(p)].l+1)%mod)%mod;
t[ls(p)].add=(t[ls(p)].add*t[p].mul+t[p].add)%mod;
t[rs(p)].add=(t[rs(p)].add*t[p].mul+t[p].add)%mod;
t[ls(p)].mul=(t[ls(p)].mul*t[p].mul)%mod;
t[rs(p)].mul=(t[rs(p)].mul*t[p].mul)%mod;
t[p].add=0;
t[p].mul=1;
}
void ch1(int p,int x,int y,int k){ //加
if(x<=t[p].l && y>=t[p].r){
t[p].add=(t[p].add+k)%mod;
t[p].sum=(t[p].sum+k*(t[p].r-t[p].l+1))%mod;
return;
}
spread(p);
int mid=(t[p].l+t[p].r)>>1;
if(x<=mid){
ch1(ls(p),x,y,k);
}
if(y>mid){
ch1(rs(p),x,y,k);
}
t[p].sum=(t[ls(p)].sum%mod+t[rs(p)].sum%mod)%mod;
}
void ch2(int p,int x,int y,int k){ //乘
if(x<=t[p].l && y>=t[p].r){
t[p].mul=(t[p].mul*k)%mod;
t[p].sum=(t[p].sum*k)%mod;
t[p].add=(t[p].add*k)%mod;
return;
}
spread(p);
int mid=(t[p].l+t[p].r)>>1;
if(x<=mid){
ch2(ls(p),x,y,k);
}
if(y>mid){
ch2(rs(p),x,y,k);
}
t[p].sum=(t[ls(p)].sum%mod+t[rs(p)].sum%mod)%mod;
}
long long ask(int p,int x,int y){
if(x<=t[p].l && y>=t[p].r){
return t[p].sum;
}
long long ans=0;
int mid=(t[p].l+t[p].r)>>1;
spread(p);
if(x<=mid){
ans=(ans%mod+ask(ls(p),x,y)%mod)%mod;
}
if(y>mid){
ans=(ask(rs(p),x,y)%mod+ans%mod)%mod;
}
return ans;
}
int main(){
scanf("%d %d %d",&n,&m,&mod);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
}
build(1,1,n);
int opt,x,y,k;
for(int i=1;i<=m;++i){
scanf("%d",&opt);
if(opt==1){
scanf("%d %d %d",&x,&y,&k);
ch2(1,x,y,k);
}
else if(opt==2){
scanf("%d %d %d",&x,&y,&k);
ch1(1,x,y,k);
}
else{
scanf("%d %d",&x,&y);
printf("%lld\n",ask(1,x,y));
}
}
return 0;
}
三、区间开方
优化:当一个数已经为1时,就不需要再进行操作。所以可以同时维护区间最大值,当区间最大值为1时,不需要操作。
(P4145 上帝造题的七分钟 2 / 花神游历各国)
#include <bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int MAX_N=100005;
long long n,m,a[MAX_N];
struct node{
int l,r;
long long sum,maxx;
}t[MAX_N<<2];
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
if(l==r){
t[p].maxx=a[l];
t[p].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
t[p].sum=t[ls].sum+t[rs].sum;
t[p].maxx=max(t[ls].maxx,t[rs].maxx);
}
void change(int p,int l,int r){
if(t[p].l==t[p].r && l<=t[p].l && t[p].r<=r){
t[p].maxx=t[p].sum=sqrt(t[p].sum);
return;
}
int mid=(t[p].l+t[p].r)>>1;
if(l<=mid && t[ls].maxx>1){
change(ls,l,r);
}
if(r>mid && t[rs].maxx>1){
change(rs,l,r);
}
t[p].sum=t[ls].sum+t[rs].sum;
t[p].maxx=max(t[ls].maxx,t[rs].maxx);
}
long long ask(int p,int l,int r){
if(l<=t[p].l && t[p].r<=r){
return t[p].sum;
}
int mid=(t[p].l+t[p].r)>>1;
long long ans=0;
if(l<=mid){
ans+=ask(ls,l,r);
}
if(r>mid){
ans+=ask(rs,l,r);
}
return ans;
}
int main(){
scanf("%lld",&n);
for(int i=1;i<=n;++i){
scanf("%lld",&a[i]);
}
build(1,1,n);
scanf("%lld",&m);
int k,l,r;
while(m--){
scanf("%d %d %d",&k,&l,&r);
if(l>r){
swap(l,r);
}
if(!k){
change(1,l,r);
}
else{
printf("%lld\n",ask(1,l,r));
}
}
return 0;
}
四、区间最大值
(P1198 [JSOI2008] 最大数)
#include <bits/stdc++.h>
using namespace std;
const int MAX_M=2e5+5;
int m,d,tot,w;
int t[MAX_M<<2];
void change(int p,int l,int r,int x,int k){
if(l==x && r==x){
t[p]=k;
return;
}
int mid=(l+r)>>1;
if(x<=mid){
change(p<<1,l,mid,x,k);
}
else{
change(p<<1|1,mid+1,r,x,k);
}
t[p]=max(t[p<<1],t[p<<1|1]);
}
int ask(int p,int l,int r,int x,int y){
if(x<=l && r<=y){
return t[p];
}
int mid=(l+r)>>1;
int ans=-0x7f;
if(x<=mid){
ans=max(ans,ask(p<<1,l,mid,x,y));
}
if(y>mid){
ans=max(ans,ask(p<<1|1,mid+1,r,x,y));
}
return ans;
}
int main(){
scanf("%d %d",&m,&d);
char opt;
int n;
for(int i=1;i<=m;++i){
cin>>opt>>n;
if(opt=='Q'){
w=ask(1,1,m,tot-n+1,tot); //只能用1到m做第一个数的l和r
printf("%d\n",w); //因为tot在变
}
else{
tot++;
n+=w;
n%=d;
change(1,1,m,tot,n);
}
}
return 0;
}
五、区间异或
(P2574 XOR的艺术)
#include <bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int MAX_N=200005;
int a[MAX_N],n,m,tot;
struct node{
int l,r,cnt0,cnt1,tag;
}t[MAX_N<<2];
void build(int p,int l,int r){
t[p].l=l;
t[p].r=r;
if(l==r){
if(a[l]){
t[p].cnt1=1;
}
else{
t[p].cnt0=1;
}
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
t[p].cnt0=t[ls].cnt0+t[rs].cnt0;
t[p].cnt1=t[ls].cnt1+t[rs].cnt1;
}
void spread(int p){
if(t[p].tag){
t[ls].tag^=1;
t[rs].tag^=1;
swap(t[ls].cnt0,t[ls].cnt1);
swap(t[rs].cnt0,t[rs].cnt1);
t[p].tag=0;
}
}
void change(int p,int l,int r){
if(l<=t[p].l && t[p].r<=r){
t[p].tag^=1;
swap(t[p].cnt0,t[p].cnt1);
return;
}
spread(p);
int mid=(t[p].l+t[p].r)>>1;
if(l<=mid){
change(ls,l,r);
}
if(r>mid){
change(rs,l,r);
}
t[p].cnt0=t[ls].cnt0+t[rs].cnt0;
t[p].cnt1=t[ls].cnt1+t[rs].cnt1;
}
int ask(int p,int l,int r){
if(l<=t[p].l && t[p].r<=r){
return t[p].cnt1;
}
int ans=0;
spread(p);
int mid=(t[p].l+t[p].r)>>1;
if(l<=mid){
ans+=ask(ls,l,r);
}
if(r>mid){
ans+=ask(rs,l,r);
}
return ans;
}
int main(){
scanf("%d %d",&n,&m);
for(int i=1;i<=n;++i){
scanf("%1d",&a[i]);
}
build(1,1,n);
int op,x,y;
while(m--){
scanf("%d %d %d",&op,&x,&y);
if(op==1){
printf("%d\n",ask(1,x,y));
}
else{
change(1,x,y);
}
}
return 0;
}
六、权值线段树
每个节点维护一个区间的数出现的次数。
动态开点和线段树合并
动态开点:不创建无关的节点。
线段树合并:将两颗线段树对应位置的节点的值合在一起,创建一颗新的线段树。
例题1:U41492 树上数颜色
#include <bits/stdc++.h>
using namespace std;
const int N=100005;
int n,m,tot,hd[N],k,ans[N],a[N],rt[N],b[N],g;
struct node{
int next,to;
}edge[N*2];
struct tree{
int ls,rs,cnt;
}t[N*24];
void add(int u,int v){
edge[++tot].to=v;
edge[tot].next=hd[u];
hd[u]=tot;
}
int build(int p,int l,int r,int w){
if(!p){
p=++k;
}
if(l==r){
t[p].cnt=1;
return p;
}
int mid=l+r>>1;
if(w<=mid){
t[p].ls=build(t[p].ls,l,mid,w);
}
else{
t[p].rs=build(t[p].rs,mid+1,r,w);
}
t[p].cnt=t[t[p].ls].cnt+t[t[p].rs].cnt;
return p;
}
int merge(int p,int q,int l,int r){
if(!p) return q;
if(!q) return p;
if(l==r){
t[p].cnt|=t[q].cnt;
return p;
}
int mid=l+r>>1;
t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
t[p].cnt=t[t[p].ls].cnt+t[t[p].rs].cnt;
return p;
}
void dfs(int u,int fa){
rt[u]=build(rt[u],1,g,a[u]);
for(int i=hd[u];i;i=edge[i].next){
int v=edge[i].to;
if(v==fa) continue;
dfs(v,u);
rt[u]=merge(rt[u],rt[v],1,g);
}
ans[u]=t[rt[u]].cnt;
}
int main(){
scanf("%d",&n);
int u,v;
for(int i=1;i<n;++i){
scanf("%d %d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
b[i]=a[i];
}
sort(b+1,b+1+n);
g=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;++i){
a[i]=lower_bound(b+1,b+1+g,a[i])-b;
}
dfs(1,0);
ans[1]=t[1].cnt;
scanf("%d",&m);
int x;
for(int i=1;i<=m;++i){
scanf("%d",&x);
printf("%d\n",ans[x]);
}
return 0;
}
例题2:P3521[POI2011]ROT-Tree Rotations
#include <bits/stdc++.h>
using namespace std;
const int N=2e5+5;
int tot,n;
long long u,v,ans;
struct node{
int ls,rs,s;
}t[N*20];
int build(int p,int l,int r,int w){
if(!p){
p=++tot;
}
if(l==r){
t[p].s=1;
return p;
}
int mid=l+r>>1;
if(w<=mid){
t[p].ls=build(t[p].ls,l,mid,w);
}
else{
t[p].rs=build(t[p].rs,mid+1,r,w);
}
t[p].s=t[t[p].ls].s+t[t[p].rs].s;
return p;
}
int merge(int p,int q,int l,int r){
if(!p) return q;
if(!q) return p;
if(l==r){
t[p].s+=t[q].s;
return p;
}
int mid=l+r>>1;
u+=(long long)t[t[p].rs].s*(long long)t[t[q].ls].s;
v+=(long long)t[t[p].ls].s*(long long)t[t[q].rs].s;
t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
t[p].s=t[t[p].ls].s+t[t[p].rs].s;
return p;
}
int dfs(){
int y,x;
scanf("%d",&x);
if(x){
y=build(0,1,n,x);
}
else{
int ls=dfs(),rs=dfs();
y=merge(ls,rs,1,n);
ans+=min(u,v);
u=v=0;
}
return y;
}
int main(){
scanf("%d", &n);
dfs();
printf("%lld",ans);
return 0;
}
例题3:P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并
注意空间复杂度
#include <bits/stdc++.h>
using namespace std;
const int N=100005;
int n,m,tot,hd[N],f[N][25],dep[N],ans[N];
int a[N],b[N],c[N],len,cnt,rt[N];
struct node{
int next,to;
}edge[N<<1];
void add(int u,int v){
edge[++tot]=node{hd[u],v};
hd[u]=tot;
}
void dfs1(int x,int fa){
dep[x]=dep[fa]+1;
for(int i=hd[x];i;i=edge[i].next){
int y=edge[i].to;
if(fa==y) continue;
f[y][0]=x;
for(int j=1;j<=20;++j){
f[y][j]=f[f[y][j-1]][j-1];
}
dfs1(y,x);
}
}
int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;--i){
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
}
if(x==y) return x;
for(int i=20;i>=0;--i){
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
}
return f[x][0];
}
struct T{
struct tree{
int ls,rs,s,id;
}t[N*64];
void push_up(int p){
if(t[t[p].ls].s>=t[t[p].rs].s) t[p].s=t[t[p].ls].s,t[p].id=t[t[p].ls].id;
else t[p].s=t[t[p].rs].s,t[p].id=t[t[p].rs].id;
}
int build(int p,int l,int r,int w,int op){
if(!p) p=++cnt;
if(l==r){
t[p].s+=op;
t[p].id=w;
return p;
}
int mid=l+r>>1;
if(w<=mid) t[p].ls=build(t[p].ls,l,mid,w,op);
else t[p].rs=build(t[p].rs,mid+1,r,w,op);
push_up(p);
return p;
}
int merge(int p,int q,int l,int r){
if(!p || !q) return p|q;
if(l==r){
t[p].s+=t[q].s;
return p;
}
int mid=l+r>>1;
t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
push_up(p);
return p;
}
}tr;
void dfs2(int x,int fa){
for(int i=hd[x];i;i=edge[i].next){
int y=edge[i].to;
if(y==fa) continue;
dfs2(y,x);
rt[x]=tr.merge(rt[x],rt[y],1,len);
}
ans[x]=tr.t[rt[x]].id;
if(tr.t[rt[x]].s==0) ans[x]=0;
}
int main(){
scanf("%d %d",&n,&m);
int x,y;
for(int i=1;i<n;++i){
scanf("%d %d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(1,1);
for(int i=1;i<=m;++i){
scanf("%d %d %d",&a[i],&b[i],&c[i]);
len=max(len,c[i]);
}
for(int i=1;i<=m;++i){
int t=LCA(a[i],b[i]);
rt[a[i]]=tr.build(rt[a[i]],1,len,c[i],1);
rt[b[i]]=tr.build(rt[b[i]],1,len,c[i],1);
rt[t]=tr.build(rt[t],1,len,c[i],-1);
if(f[t][0]) rt[f[t][0]]=tr.build(rt[f[t][0]],1,len,c[i],-1);
}
dfs2(1,1);
for(int i=1;i<=n;++i) printf("%d\n",ans[i]);
return 0;
}