线段树知识总结

· · 个人记录

适用于维护满足结合律的运算,如区间和、区间最大值、区间最大公约数。

难点(精髓)在于lazytag实现的不需要就不修改。

一、区间加

(P3372 【模板】线段树 1)

#include <bits/stdc++.h>
using namespace std;
const int MAX_N=100005;
struct node{
    int l,r;
    long long sum,add;
}tree[MAX_N<<2];
int a[MAX_N],n,m;
void build(int p,int l,int r){
    tree[p].l=l;
    tree[p].r=r;
    if(l==r){
        tree[p].sum=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(p*2,l,mid);
    build(p*2+1,mid+1,r);
    tree[p].sum=tree[p*2].sum+tree[p*2+1].sum;
}
void spread(int p){
    tree[p*2].sum+=tree[p].add*(tree[p*2].r-tree[p*2].l+1);
    tree[p*2+1].sum+=tree[p].add*(tree[p*2+1].r-tree[p*2+1].l+1);
    tree[p*2].add+=tree[p].add;
    tree[p*2+1].add+=tree[p].add;
    tree[p].add=0;
}
void change(int p,int x,int y,int k){
    if(x<=tree[p].l && y>=tree[p].r){
        tree[p].add+=k;
        tree[p].sum+=k*(tree[p].r-tree[p].l+1);
        return;
    }
    if(tree[p].add){
        spread(p);
    }
    int mid=(tree[p].l+tree[p].r)>>1;
    if(x<=mid){
        change(p*2,x,y,k);
    }
    if(y>mid){
        change(p*2+1,x,y,k);
    }
    tree[p].sum=tree[p*2].sum+tree[p*2+1].sum;
}
long long ask(int p,int x,int y){
    if(x<=tree[p].l && y>=tree[p].r){
        return tree[p].sum;
    }
    long long ans=0;
    int mid=(tree[p].l+tree[p].r)>>1;
    if(tree[p].add){
        spread(p);
    }
    if(x<=mid){
        ans+=ask(p*2,x,y);
    }
    if(y>mid){
        ans+=ask(p*2+1,x,y);
    }
    return ans;
}
int main(){
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%d",&a[i]);
    }
    build(1,1,n);
    int opt,x,y,k;
    for(int i=1;i<=m;++i){
        scanf("%d",&opt);
        if(opt==1){
            scanf("%d %d %d",&x,&y,&k);
            change(1,x,y,k);
        }
        else{
            scanf("%d %d",&x,&y);
            printf("%lld\n",ask(1,x,y));
        }
    }
    return 0;
}

二、区间乘

注意先乘再加。(乘法分配律)

(P3373 【模板】线段树 2)

#include <bits/stdc++.h>
using namespace std;
const int MAX_N=100005;
struct node{
    int l,r;
    long long sum,mul=1,add;
}t[MAX_N<<2];
int mod,n,m,a[MAX_N];
int ls(int x){
    return x*2;
}
int rs(int x){
    return x*2+1;
}
void build(int p,int l,int r){
    t[p].l=l;
    t[p].r=r;
    if(l==r){
        t[p].sum=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(ls(p),l,mid);
    build(rs(p),mid+1,r);
    t[p].sum=(t[ls(p)].sum+t[rs(p)].sum)%mod;
}
void spread(int p){
    t[ls(p)].sum=((t[ls(p)].sum*t[p].mul)%mod+t[p].add*(t[ls(p)].r-t[ls(p)].l+1)%mod)%mod;
    t[rs(p)].sum=((t[rs(p)].sum*t[p].mul)%mod+t[p].add*(t[rs(p)].r-t[rs(p)].l+1)%mod)%mod;
    t[ls(p)].add=(t[ls(p)].add*t[p].mul+t[p].add)%mod;
    t[rs(p)].add=(t[rs(p)].add*t[p].mul+t[p].add)%mod;
    t[ls(p)].mul=(t[ls(p)].mul*t[p].mul)%mod;
    t[rs(p)].mul=(t[rs(p)].mul*t[p].mul)%mod;
    t[p].add=0;
    t[p].mul=1;
}
void ch1(int p,int x,int y,int k){      //加 
    if(x<=t[p].l && y>=t[p].r){
        t[p].add=(t[p].add+k)%mod;
        t[p].sum=(t[p].sum+k*(t[p].r-t[p].l+1))%mod;
        return;
    }
    spread(p);
    int mid=(t[p].l+t[p].r)>>1;
    if(x<=mid){
        ch1(ls(p),x,y,k);
    }
    if(y>mid){
        ch1(rs(p),x,y,k);
    }
    t[p].sum=(t[ls(p)].sum%mod+t[rs(p)].sum%mod)%mod;
}
void ch2(int p,int x,int y,int k){      //乘 
    if(x<=t[p].l && y>=t[p].r){
        t[p].mul=(t[p].mul*k)%mod;
        t[p].sum=(t[p].sum*k)%mod;
        t[p].add=(t[p].add*k)%mod;
        return;
    }
    spread(p);
    int mid=(t[p].l+t[p].r)>>1;
    if(x<=mid){
        ch2(ls(p),x,y,k);
    }
    if(y>mid){
        ch2(rs(p),x,y,k);
    }
    t[p].sum=(t[ls(p)].sum%mod+t[rs(p)].sum%mod)%mod;
}
long long ask(int p,int x,int y){
    if(x<=t[p].l && y>=t[p].r){
        return t[p].sum;
    }
    long long ans=0;
    int mid=(t[p].l+t[p].r)>>1;
    spread(p);
    if(x<=mid){
        ans=(ans%mod+ask(ls(p),x,y)%mod)%mod;
    }
    if(y>mid){
        ans=(ask(rs(p),x,y)%mod+ans%mod)%mod;
    }
    return ans;
}
int main(){
    scanf("%d %d %d",&n,&m,&mod);
    for(int i=1;i<=n;++i){
        scanf("%d",&a[i]);
    }
    build(1,1,n);
    int opt,x,y,k;
    for(int i=1;i<=m;++i){
        scanf("%d",&opt);
        if(opt==1){
            scanf("%d %d %d",&x,&y,&k);
            ch2(1,x,y,k);
        }
        else if(opt==2){
            scanf("%d %d %d",&x,&y,&k);
            ch1(1,x,y,k);
        }
        else{
            scanf("%d %d",&x,&y);
            printf("%lld\n",ask(1,x,y));
        }
    }
    return 0;
}

三、区间开方

优化:当一个数已经为1时,就不需要再进行操作。所以可以同时维护区间最大值,当区间最大值为1时,不需要操作。

(P4145 上帝造题的七分钟 2 / 花神游历各国)

#include <bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int MAX_N=100005;
long long n,m,a[MAX_N];
struct node{
    int l,r;
    long long sum,maxx;
}t[MAX_N<<2];
void build(int p,int l,int r){
    t[p].l=l;
    t[p].r=r;
    if(l==r){
        t[p].maxx=a[l];
        t[p].sum=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    t[p].sum=t[ls].sum+t[rs].sum;
    t[p].maxx=max(t[ls].maxx,t[rs].maxx);
}
void change(int p,int l,int r){
    if(t[p].l==t[p].r && l<=t[p].l && t[p].r<=r){
        t[p].maxx=t[p].sum=sqrt(t[p].sum);
        return;
    }
    int mid=(t[p].l+t[p].r)>>1;
    if(l<=mid && t[ls].maxx>1){
        change(ls,l,r);
    }
    if(r>mid && t[rs].maxx>1){
        change(rs,l,r);
    }
    t[p].sum=t[ls].sum+t[rs].sum;
    t[p].maxx=max(t[ls].maxx,t[rs].maxx);
}
long long ask(int p,int l,int r){
    if(l<=t[p].l && t[p].r<=r){
        return t[p].sum;
    }
    int mid=(t[p].l+t[p].r)>>1;
    long long ans=0;
    if(l<=mid){
        ans+=ask(ls,l,r);
    }
    if(r>mid){
        ans+=ask(rs,l,r);
    }
    return ans;
}
int main(){
    scanf("%lld",&n);
    for(int i=1;i<=n;++i){
        scanf("%lld",&a[i]);
    }
    build(1,1,n);
    scanf("%lld",&m);
    int k,l,r;
    while(m--){
        scanf("%d %d %d",&k,&l,&r);
        if(l>r){
            swap(l,r);
        }
        if(!k){
            change(1,l,r);
        }
        else{
            printf("%lld\n",ask(1,l,r));
        }
    }
    return 0;
}

四、区间最大值

(P1198 [JSOI2008] 最大数)

#include <bits/stdc++.h>
using namespace std;
const int MAX_M=2e5+5;
int m,d,tot,w;
int t[MAX_M<<2];
void change(int p,int l,int r,int x,int k){
    if(l==x && r==x){
        t[p]=k;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid){
        change(p<<1,l,mid,x,k);
    }
    else{
        change(p<<1|1,mid+1,r,x,k);
    }
    t[p]=max(t[p<<1],t[p<<1|1]);
}
int ask(int p,int l,int r,int x,int y){
    if(x<=l && r<=y){
        return t[p];
    }
    int mid=(l+r)>>1;
    int ans=-0x7f;
    if(x<=mid){
        ans=max(ans,ask(p<<1,l,mid,x,y));
    }
    if(y>mid){
        ans=max(ans,ask(p<<1|1,mid+1,r,x,y));
    }
    return ans;
}
int main(){
    scanf("%d %d",&m,&d);
    char opt;
    int n;
    for(int i=1;i<=m;++i){
        cin>>opt>>n;
        if(opt=='Q'){
            w=ask(1,1,m,tot-n+1,tot);       //只能用1到m做第一个数的l和r 
            printf("%d\n",w);               //因为tot在变 
        }
        else{
            tot++;
            n+=w;
            n%=d;
            change(1,1,m,tot,n);
        }
    }
    return 0;
} 

五、区间异或

(P2574 XOR的艺术)

#include <bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int MAX_N=200005;
int a[MAX_N],n,m,tot;
struct node{
    int l,r,cnt0,cnt1,tag;
}t[MAX_N<<2];
void build(int p,int l,int r){
    t[p].l=l;
    t[p].r=r;
    if(l==r){
        if(a[l]){
            t[p].cnt1=1;
        }
        else{
            t[p].cnt0=1;
        }
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    t[p].cnt0=t[ls].cnt0+t[rs].cnt0;
    t[p].cnt1=t[ls].cnt1+t[rs].cnt1;
}
void spread(int p){
    if(t[p].tag){
        t[ls].tag^=1;
        t[rs].tag^=1;
        swap(t[ls].cnt0,t[ls].cnt1);
        swap(t[rs].cnt0,t[rs].cnt1);
        t[p].tag=0;
    }
}
void change(int p,int l,int r){
    if(l<=t[p].l && t[p].r<=r){
        t[p].tag^=1;
        swap(t[p].cnt0,t[p].cnt1);
        return;
    }
    spread(p);
    int mid=(t[p].l+t[p].r)>>1;
    if(l<=mid){
        change(ls,l,r);
    }
    if(r>mid){
        change(rs,l,r);
    }
    t[p].cnt0=t[ls].cnt0+t[rs].cnt0;
    t[p].cnt1=t[ls].cnt1+t[rs].cnt1;
}
int ask(int p,int l,int r){
    if(l<=t[p].l && t[p].r<=r){
        return t[p].cnt1;
    }
    int ans=0;
    spread(p);
    int mid=(t[p].l+t[p].r)>>1;
    if(l<=mid){
        ans+=ask(ls,l,r);
    }
    if(r>mid){
        ans+=ask(rs,l,r);
    }
    return ans;
}
int main(){
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%1d",&a[i]);
    }
    build(1,1,n);
    int op,x,y;
    while(m--){
        scanf("%d %d %d",&op,&x,&y);
        if(op==1){
            printf("%d\n",ask(1,x,y));
        }
        else{
            change(1,x,y);
        }
    }
    return 0;
}

六、权值线段树

每个节点维护一个区间的数出现的次数。

动态开点和线段树合并

动态开点:不创建无关的节点。

线段树合并:将两颗线段树对应位置的节点的值合在一起,创建一颗新的线段树。

例题1:U41492 树上数颜色

#include <bits/stdc++.h>
using namespace std;
const int N=100005;
int n,m,tot,hd[N],k,ans[N],a[N],rt[N],b[N],g;
struct node{
    int next,to;
}edge[N*2];
struct tree{
    int ls,rs,cnt;
}t[N*24];
void add(int u,int v){
    edge[++tot].to=v;
    edge[tot].next=hd[u];
    hd[u]=tot;
}
int build(int p,int l,int r,int w){
    if(!p){
        p=++k;
    }
    if(l==r){
        t[p].cnt=1;
        return p;
    }
    int mid=l+r>>1;
    if(w<=mid){
        t[p].ls=build(t[p].ls,l,mid,w);
    }
    else{
        t[p].rs=build(t[p].rs,mid+1,r,w);
    }
    t[p].cnt=t[t[p].ls].cnt+t[t[p].rs].cnt;
    return p;
}
int merge(int p,int q,int l,int r){
    if(!p) return q;
    if(!q) return p;
    if(l==r){
        t[p].cnt|=t[q].cnt;
        return p;
    }
    int mid=l+r>>1;
    t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
    t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
    t[p].cnt=t[t[p].ls].cnt+t[t[p].rs].cnt;
    return p;
}
void dfs(int u,int fa){
    rt[u]=build(rt[u],1,g,a[u]);
    for(int i=hd[u];i;i=edge[i].next){
        int v=edge[i].to;
        if(v==fa) continue;
        dfs(v,u);
        rt[u]=merge(rt[u],rt[v],1,g);
    }
    ans[u]=t[rt[u]].cnt;
}
int main(){
    scanf("%d",&n);
    int u,v;
    for(int i=1;i<n;++i){
        scanf("%d %d",&u,&v);
        add(u,v);
        add(v,u);
    }
    for(int i=1;i<=n;++i){
        scanf("%d",&a[i]);
        b[i]=a[i];
    }
    sort(b+1,b+1+n);
    g=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;++i){
        a[i]=lower_bound(b+1,b+1+g,a[i])-b;
    }
    dfs(1,0);
    ans[1]=t[1].cnt;
    scanf("%d",&m);
    int x;
    for(int i=1;i<=m;++i){
        scanf("%d",&x);
        printf("%d\n",ans[x]);
    }
    return 0;
}

例题2:P3521[POI2011]ROT-Tree Rotations

#include <bits/stdc++.h>
using namespace std;
const int N=2e5+5;
int tot,n;
long long u,v,ans;
struct node{
    int ls,rs,s;
}t[N*20];
int build(int p,int l,int r,int w){
    if(!p){
        p=++tot;
    }
    if(l==r){
        t[p].s=1;
        return p;
    }
    int mid=l+r>>1;
    if(w<=mid){
        t[p].ls=build(t[p].ls,l,mid,w);
    }
    else{
        t[p].rs=build(t[p].rs,mid+1,r,w);
    }
    t[p].s=t[t[p].ls].s+t[t[p].rs].s;
    return p;
}
int merge(int p,int q,int l,int r){
    if(!p) return q;
    if(!q) return p;
    if(l==r){
        t[p].s+=t[q].s;
        return p;
    }
    int mid=l+r>>1;
    u+=(long long)t[t[p].rs].s*(long long)t[t[q].ls].s;
    v+=(long long)t[t[p].ls].s*(long long)t[t[q].rs].s;
    t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
    t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
    t[p].s=t[t[p].ls].s+t[t[p].rs].s;
    return p;
}
int dfs(){
    int y,x;
    scanf("%d",&x);
    if(x){
        y=build(0,1,n,x);
    }
    else{
        int ls=dfs(),rs=dfs();
        y=merge(ls,rs,1,n);
        ans+=min(u,v);
        u=v=0;
    }
    return y;
}
int main(){
    scanf("%d", &n);
    dfs();
    printf("%lld",ans);
    return 0;
}

例题3:P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并

注意空间复杂度

#include <bits/stdc++.h>
using namespace std;
const int N=100005;
int n,m,tot,hd[N],f[N][25],dep[N],ans[N];
int a[N],b[N],c[N],len,cnt,rt[N];
struct node{
    int next,to;
}edge[N<<1];
void add(int u,int v){
    edge[++tot]=node{hd[u],v};
    hd[u]=tot;
}
void dfs1(int x,int fa){
    dep[x]=dep[fa]+1;
    for(int i=hd[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(fa==y) continue;
        f[y][0]=x;
        for(int j=1;j<=20;++j){
            f[y][j]=f[f[y][j-1]][j-1];
        }
        dfs1(y,x);
    }
}
int LCA(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;--i){
        if(dep[f[x][i]]>=dep[y]) x=f[x][i];
    }
    if(x==y) return x;
    for(int i=20;i>=0;--i){
        if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    }
    return f[x][0];
}
struct T{
    struct tree{
        int ls,rs,s,id;
    }t[N*64];
    void push_up(int p){
        if(t[t[p].ls].s>=t[t[p].rs].s) t[p].s=t[t[p].ls].s,t[p].id=t[t[p].ls].id;
        else t[p].s=t[t[p].rs].s,t[p].id=t[t[p].rs].id;
    }
    int build(int p,int l,int r,int w,int op){
        if(!p) p=++cnt;
        if(l==r){
            t[p].s+=op;
            t[p].id=w;
            return p;
        }
        int mid=l+r>>1;
        if(w<=mid) t[p].ls=build(t[p].ls,l,mid,w,op);
        else t[p].rs=build(t[p].rs,mid+1,r,w,op);
        push_up(p);
        return p;
    }
    int merge(int p,int q,int l,int r){
        if(!p || !q) return p|q;
        if(l==r){
            t[p].s+=t[q].s;
            return p;
        }
        int mid=l+r>>1;
        t[p].ls=merge(t[p].ls,t[q].ls,l,mid);
        t[p].rs=merge(t[p].rs,t[q].rs,mid+1,r);
        push_up(p);
        return p;
    }
}tr;
void dfs2(int x,int fa){
    for(int i=hd[x];i;i=edge[i].next){
        int y=edge[i].to;
        if(y==fa) continue;
        dfs2(y,x);
        rt[x]=tr.merge(rt[x],rt[y],1,len);
    }
    ans[x]=tr.t[rt[x]].id;
    if(tr.t[rt[x]].s==0) ans[x]=0;
}
int main(){
    scanf("%d %d",&n,&m);
    int x,y;
    for(int i=1;i<n;++i){
        scanf("%d %d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(1,1);
    for(int i=1;i<=m;++i){
        scanf("%d %d %d",&a[i],&b[i],&c[i]);
        len=max(len,c[i]);
    }
    for(int i=1;i<=m;++i){
        int t=LCA(a[i],b[i]);
        rt[a[i]]=tr.build(rt[a[i]],1,len,c[i],1);
        rt[b[i]]=tr.build(rt[b[i]],1,len,c[i],1);
        rt[t]=tr.build(rt[t],1,len,c[i],-1);
        if(f[t][0]) rt[f[t][0]]=tr.build(rt[f[t][0]],1,len,c[i],-1);
    }
    dfs2(1,1);
    for(int i=1;i<=n;++i) printf("%d\n",ans[i]);
    return 0;
}