线段树的一些应用

· · 个人记录

最近几十天总算把线段树的一些模板题吃透了。

【例 1】序列,单点加 k,区间求和。(洛谷 P3374)

代码:

#include<bits/stdc++.h>
using namespace std;
// Shared state for Example 1 (point add, range sum).
using ll = long long;
constexpr ll N = 500009;  // capacity: largest supported n plus some slack

// One segment-tree node: covers positions [l, r] and stores the sum of x[l..r].
struct Segment {
    ll l, r, sum;
};
Segment tr[4 * N];  // heap layout: root is 1, children of u are 2u and 2u+1
ll x[N];            // input array, 1-indexed
/// Recursively builds node u over positions [l, r] of x: records the range
/// and the sum of the covered elements.
void build(ll u,ll l,ll r){
    Segment& node = tr[u];
    node.l = l;
    node.r = r;
    if (l == r) {             // leaf: a single element
        node.sum = x[l];
        return;
    }
    const ll half = (l + r) / 2;
    build(2 * u, l, half);
    build(2 * u + 1, half + 1, r);
    node.sum = tr[2 * u].sum + tr[2 * u + 1].sum;  // pull up from children
}
/// Adds delta to position x. Walks from node u (the root when u = 1) down to
/// the leaf for x, adding delta to the sum of every node on that path.
void add(ll u,ll x,ll delta){
    for (;;) {
        tr[u].sum += delta;
        if (tr[u].l == tr[u].r) break;             // reached the leaf for x
        u = (x <= tr[2 * u].r) ? 2 * u : 2 * u + 1; // descend toward x
    }
}
/// Returns the sum of the query range [l, r] restricted to node u's range.
ll rsq(ll u,ll l,ll r){
    const Segment& node = tr[u];
    if (node.r < l || node.l > r) return 0;           // no overlap
    if (l <= node.l && node.r <= r) return node.sum;  // fully inside the query
    return rsq(2 * u, l, r) + rsq(2 * u + 1, l, r);   // partial overlap
}
int main(){
    ll n,m;
    cin>>n>>m;
    for(ll i=1;i<=n;i++) cin>>x[i];
    build(1,1,n);
    for(ll i=1;i<=m;i++){
        ll op,x,y;
        cin>>op>>x>>y;
        if(op==1) add(1,x,y);
        else cout<<rsq(1,x,y)<<endl;
    }
    return 0;
}

【例 2】序列,区间加 k,区间求和。(洛谷 P3372)

只需简单 pushdown。代码:

#include<bits/stdc++.h>
using namespace std;
// Example 2: range add, range sum, with one lazy "pending add" tag per node.
// Values, tags and added amounts are all long long. A node's tag is the SUM of
// every add still pending on it, so it can leave the int range even when each
// single k fits; the input values are not limited to int either.
typedef long long ll;
const int N=100010,M=4*N;
int n,m;
ll a[N];    // input array, 1-indexed
ll b[M];    // b[k]: amount still to be added to every element below node k
ll sum[M];  // sum[k]: sum of node k's range, already including b[k]

/// Builds node k over a[l..r].
void build(int k,int l,int r){
    if(l==r){
        sum[k]=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(k*2,l,mid);
    build(k<<1|1,mid+1,r);
    sum[k]=sum[k*2]+sum[k<<1|1];
}

/// Applies "add x to every element" to node k covering [l, r]: the sum is
/// updated at once, and x is recorded in the tag for later pushdown.
void add(int k,int l,int r,ll x){
    b[k]+=x;
    sum[k]+=x*(r-l+1);
}

/// Moves node k's pending tag onto its children [l, mid] and [mid+1, r].
void pushdown(int k,int l,int r,int mid){
    if(b[k]==0) return;
    add(k*2,l,mid,b[k]);
    add(k<<1|1,mid+1,r,b[k]);
    b[k]=0;
}

/// Returns the sum over [x, y] within node k's range [l, r]. A child is only
/// visited when it overlaps [x, y] (the x<=mid / mid<y checks).
ll query(int k,int l,int r,int x,int y){
    if(l>=x && r<=y) return sum[k];
    int mid=(l+r)>>1;
    ll res=0;
    pushdown(k,l,r,mid);
    if(x<=mid) res+=query(k*2,l,mid,x,y);
    if(mid<y) res+=query(k<<1|1,mid+1,r,x,y);
    return res;
}

/// Adds t to every element of [x, y] within node k's range [l, r].
void modify(int k,int l,int r,int x,int y,ll t){
    if(l>=x && r<=y) return add(k,l,r,t);
    int mid=(l+r)>>1;
    pushdown(k,l,r,mid);
    if(x<=mid) modify(k*2,l,mid,x,y,t);
    if(mid<y) modify(k<<1|1,mid+1,r,x,y,t);
    sum[k]=sum[k*2]+sum[k<<1|1];
}

/// Reads n, m and the array, then m operations:
///   "1 x y k" adds k to [x, y]; any other op "op x y" prints the sum of [x, y].
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    build(1,1,n);
    while(m--){
        int op,x,y;
        scanf("%d%d%d",&op,&x,&y);
        if(op==1){
            ll k;
            scanf("%lld",&k);
            modify(1,1,n,x,y,k);
        }
        else printf("%lld\n",query(1,1,n,x,y));
    }
    return 0;
}

【例 3】序列,区间乘 k、区间加 k,区间求和(结果对模数取模)。(洛谷 P3373)

注意 pushdown 时先乘后加:懒标记 (mul, add) 表示把子树中每个数 x 变成 x·mul + add,因此下传时,子节点原有的加标记要先乘上父节点的乘标记,再加上父节点的加标记。

#include<bits/stdc++.h>
using namespace std;
// Shared state for Example 3 (range multiply, range add, range sum mod MOD).
using ll = long long;
constexpr int N = 800009;
ll n, m, MOD;  // array length, operation count, modulus (all read in main)

// Node over [l, r]. The pending tags mean that every element v in the node's
// children must still become v * toMul + toAdd (multiply first, then add).
struct Segment {
    ll l, r;   // covered range
    ll len;    // r - l + 1
    ll sum;    // range sum, already reflecting this node's own tags
    ll toMul;  // pending multiplier, identity 1
    ll toAdd;  // pending addend, identity 0
};
Segment tr[N * 4];
int x[N];  // input array, 1-indexed
/// Builds node u over x[l..r] with identity tags (toMul = 1, toAdd = 0).
/// Sums are stored reduced into [0, MOD). Otherwise a query that fully covers
/// an untouched node would return the raw, unreduced sum, and the `sum * mul`
/// products in pushdown/mul could overflow for a large MOD.
void build(ll u,ll l,ll r){
    if(l==r){
        tr[u]=(Segment){l,r,r-l+1,((ll)x[l]%MOD+MOD)%MOD,1,0};
        return;
    }
    ll m=(l+r)/2;
    build(u*2,l,m);
    build(u*2+1,m+1,r);
    tr[u]=(Segment){l,r,r-l+1,(tr[u*2].sum+tr[u*2+1].sum)%MOD,1,0};
}
/// Pushes node u's pending (toMul, toAdd) tags onto both children, then resets
/// u's tags to the identity. Composition: a child's tags and sum are first
/// multiplied by the parent's multiplier, then the parent's addend is applied.
void pushdown(ll u){
    Segment& node = tr[u];
    if (node.l == node.r) return;  // leaves have no children
    const ll mulTag = node.toMul;
    const ll addTag = node.toAdd;
    node.toMul = 1;
    node.toAdd = 0;
    for (ll c = u * 2; c <= u * 2 + 1; ++c) {
        Segment& child = tr[c];
        child.toMul = child.toMul * mulTag % MOD;
        child.toAdd = (child.toAdd * mulTag + addTag) % MOD;
        child.sum = (child.sum * mulTag + child.len * addTag) % MOD;
    }
}
/// Adds delta (mod MOD) to every position of [l, r] inside node u's range.
/// l, r and delta are taken by reference but are never modified.
void add(ll u,ll &l,ll &r,ll &delta){
    pushdown(u);
    if(r<tr[u].l || tr[u].r<l) return;
    if(l<=tr[u].l && tr[u].r<=r){
        // Reduce into [0, MOD) first: a raw large delta could overflow
        // len*delta, and a negative one would leave negative residues.
        const ll d=(delta%MOD+MOD)%MOD;
        (tr[u].toAdd+=d)%=MOD;
        (tr[u].sum+=tr[u].len*d)%=MOD;
        return;
    }
    add(u*2,l,r,delta);
    add(u*2+1,l,r,delta);
    tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}
/// Multiplies every position of [l, r] inside node u's range by delta (mod MOD).
/// Both pending tags and the sum scale by delta (multiply-before-add rule).
/// l, r and delta are taken by reference but are never modified.
void mul(ll u,ll &l,ll &r,ll &delta){
    pushdown(u);
    if(r<tr[u].l || tr[u].r<l) return;
    if(l<=tr[u].l && tr[u].r<=r){
        // Reduce into [0, MOD) first so each product stays below MOD^2.
        const ll d=(delta%MOD+MOD)%MOD;
        (tr[u].toMul*=d)%=MOD;
        (tr[u].toAdd*=d)%=MOD;
        (tr[u].sum*=d)%=MOD;
        return;
    }
    mul(u*2,l,r,delta);
    mul(u*2+1,l,r,delta);
    tr[u].sum=(tr[u*2].sum+tr[u*2+1].sum)%MOD;
}
/// Returns the sum over [l, r] within node u's range, modulo MOD.
ll rsq(ll u,ll &l,ll &r){
    pushdown(u);
    const Segment& node = tr[u];
    const bool disjoint = r < node.l || node.r < l;
    if (disjoint) return 0;
    const bool covered = l <= node.l && node.r <= r;
    if (covered) return node.sum;
    const ll leftPart = rsq(u * 2, l, r);
    const ll rightPart = rsq(u * 2 + 1, l, r);
    return (leftPart + rightPart) % MOD;
}
int main(){
    scanf("%lld %lld %lld",&n,&m,&MOD);
    for(int i=1;i<=n;i++) scanf("%d",&x[i]);
    build(1,1,n);
    for(int i=1;i<=m;i++){
        ll t,x,y,z;
        scanf("%lld %lld %lld",&t,&x,&y);
        if(t==3) printf("%lld\n",rsq(1,x,y));
        else if(t==2){
            scanf("%lld",&z);
            add(1,x,y,z);
        }
        else{
            scanf("%lld",&z);
            mul(1,x,y,z);
        }
    }
    return 0;
}

【例 4】区间加,求区间平均值。(洛谷 P14233)

注意这题需要把时刻转换成秒数(hh×3600 + mm×60 + ss,再加 1,映射到下标 1~86400)。若起点晚于终点,说明区间跨过了午夜,要拆成 [l, 86400] 和 [1, r] 两段处理。区间平均值 = 区间和 / 区间长度。

代码:

#include<bits/stdc++.h>
using namespace std;
// Shared state for Example 4 (range add, range average on a 24-hour clock).
using ll = long long;
constexpr ll N = 300009, INF = 2e9;  // NOTE(review): INF is not used by this program
ll n, m, MOD;                        // NOTE(review): unused here; main declares its own n

// Node over [l, r]: sum of the covered slots plus a pending add for the children.
struct Segment {
    ll l, r;
    ll len;    // r - l + 1
    ll sum;
    ll toAdd;
};
Segment tr[N * 4];
/// Builds node u over [l, r]; every slot starts at 0 with no pending add.
void build(ll u,ll l,ll r){
    tr[u] = (Segment){l, r, r - l + 1, 0, 0};
    if (l == r) return;
    const ll half = (l + r) / 2;
    build(u * 2, l, half);
    build(u * 2 + 1, half + 1, r);
    tr[u].sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
}
/// Moves node u's pending add onto both children (their sums and tags).
void pushdown(ll u){
    if (tr[u].l == tr[u].r) return;  // leaf: nothing below
    const ll pending = tr[u].toAdd;
    tr[u].toAdd = 0;
    for (const ll c : {u * 2, u * 2 + 1}) {
        tr[c].sum += tr[c].len * pending;
        tr[c].toAdd += pending;
    }
}
/// Adds delta to every slot of [l, r] within node u's range.
void add(ll u,ll l,ll r,ll delta){
    pushdown(u);
    Segment& node = tr[u];
    if (r < node.l || node.r < l) return;  // no overlap
    if (node.l >= l && node.r <= r) {      // fully covered: tag and stop
        node.toAdd += delta;
        node.sum += node.len * delta;
        return;
    }
    add(u * 2, l, r, delta);
    add(u * 2 + 1, l, r, delta);
    node.sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
}
/// Returns the sum of the slots in [l, r] within node u's range.
ll rsq(ll u,ll l,ll r){
    pushdown(u);
    const Segment& node = tr[u];
    if (node.r < l || r < node.l) return 0;
    return (l <= node.l && node.r <= r) ? node.sum
                                        : rsq(u * 2, l, r) + rsq(u * 2 + 1, l, r);
}
int main(){
    ll n;
    cin>>n;
    build(1,1,86400);
    for(ll i=1;i<=n;i++){
        ll x,y,z,a,b,c;
        char op;
        cin>>x>>op>>y>>op>>z>>op>>a>>op>>b>>op>>c;
        ll l=x*3600+y*60+z;
        ll r=a*3600+b*60+c;
        l++;r++;
        if(l>r){
            add(1,l,86400,1);
            add(1,1,r,1);
        }
        else add(1,l,r,1); 
    }
    ll q;
    cin>>q;
    for(ll i=1;i<=q;i++){
        ll x,y,z,a,b,c;
        char op;
        cin>>x>>op>>y>>op>>z>>op>>a>>op>>b>>op>>c;
        ll l=x*3600+y*60+z;
        ll r=a*3600+b*60+c;
        l++;r++;
        if(l>r)
            cout<<fixed<<setprecision(10)<<((double)rsq(1,l,86400)+rsq(1,1,r))/(r+86400-l+1)<<endl;
        else
            cout<<fixed<<setprecision(10)<<((double)rsq(1,l,r))/(r-l+1)<<endl;
    }
    return 0;
}

【例 5】区间加,区间 sin 和。(洛谷 P6327)

注意到 $\sin(\alpha+\beta)=\sin\alpha\cos\beta+\cos\alpha\sin\beta$,以及 $\cos(\alpha+\beta)=\cos\alpha\cos\beta-\sin\alpha\sin\beta$。因此每个节点同时维护区间 sin 和与 cos 和。区间整体加上 $\beta$ 时,新的两个和都能由旧的 sin 和、cos 和直接算出,pushdown 显然。

#include<bits/stdc++.h>
using namespace std;
#define int long long  // every `int` in this program is 64-bit
// Shared state for Example 5 (range add, range sum of sin).
const int N = 300009;
int x[N], n;  // input array (1-indexed) and its length

// Node over [l, r]: sums of sin(x_i) and cos(x_i) over its range, plus a
// pending add that its children have not received yet.
struct Segment {
    int l, r;
    double sumsin, sumcos;
    int toAdd;
} tr[N * 4];
void build(int u,int l,int r){
    if(l==r){
        tr[u]=(Segment){l,r,sin(x[l]),cos(x[l]),0};
        return;
    }
    int mid=(l+r)/2;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
    tr[u]=(Segment){l,r,tr[u*2].sumsin+tr[u*2+1].sumsin,tr[u*2].sumcos+tr[u*2+1].sumcos,0};
}
void sincosadd(int u,double sina,double cosa){
    double x=tr[u].sumsin,y=tr[u].sumcos;
    tr[u].sumsin=x*cosa+y*sina;
    tr[u].sumcos=y*cosa-x*sina;
}
/// Pushes node u's pending add onto both children (shifting their sin/cos sums
/// and accumulating their tags), then clears it.
void pushdown(int u){
    // Leaves have no children. Without this guard a tagged leaf writes into
    // tr[2u] / tr[2u+1], which are not tree nodes and can lie past the end of
    // tr[] when n is close to N. The other examples in this file have this guard.
    if(tr[u].l==tr[u].r) return;
    int A=tr[u].toAdd;
    if(!A) return;
    tr[u].toAdd=0;
    double sinx=sin(A),cosx=cos(A);
    sincosadd(u*2,sinx,cosx);
    sincosadd(u*2+1,sinx,cosx);
    tr[u*2].toAdd+=A;
    tr[u*2+1].toAdd+=A;
}
void add(int u,int l,int r,int delta){
    pushdown(u);
    if(tr[u].r<l||r<tr[u].l) return;
    if(l<=tr[u].l&&tr[u].r<=r){
        sincosadd(u,sin(delta),cos(delta));
        tr[u].toAdd+=delta;
        return;
    }
    add(u*2,l,r,delta);
    add(u*2+1,l,r,delta);
    tr[u].sumsin=tr[u*2].sumsin+tr[u*2+1].sumsin;
    tr[u].sumcos=tr[u*2].sumcos+tr[u*2+1].sumcos;
}
double rsq(int u,int l,int r){
    pushdown(u);
    if(tr[u].r<l||r<tr[u].l) return 0;
    if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sumsin;
    return rsq(u*2,l,r)+rsq(u*2+1,l,r);
}
/// Reads n and the array, then m operations:
///   "2 l r" prints the sum of sin over [l, r] with one decimal;
///   any other op "op l r v" adds v to [l, r].
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(nullptr);  // otherwise cout is flushed before every single read
    cin>>n;
    for(int i=1;i<=n;i++) cin>>x[i];
    build(1,1,n);
    int m;
    cin>>m;
    cout<<fixed<<setprecision(1);  // sticky: set once for all answers
    for(int i=1;i<=m;i++){
        int op,l,r;
        cin>>op>>l>>r;
        if(op==2) cout<<rsq(1,l,r)<<'\n';
        else{
            int v;
            cin>>v;
            add(1,l,r,v);
        }
    }
    return 0;
}

【例 6】区间加,区间方差。(洛谷 P1471)

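简单的题目:平均值 = 区间和 / 区间长度,方差 = 平方和 / 区间长度 − 平均值²。维护区间和与区间平方和即可。区间加 $a$ 时,平方和按 $\sum (x_i+a)^2=\sum x_i^2+2a\sum x_i+\text{len}\cdot a^2$ 更新;这里要用加之前的区间和,所以先更新平方和,再更新区间和。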

#include<bits/stdc++.h>
using namespace std;
// Shared state for Example 6 (range add, range mean and variance).
using ll = long long;
constexpr ll MOD = 1e9 + 7;  // NOTE(review): not used by this program
constexpr ll N = 300009;

// Node over [l, r]: sum and sum of squares of its elements, plus a pending add.
struct Segment {
    ll l, r;
    ll len;        // r - l + 1
    double sum;
    double sqsum;
    double toAdd;
} tr[N * 4];
double a[N];  // input array, 1-indexed
ll n, m;
/// Builds node u over a[l..r], storing the sum and the sum of squares.
void build(ll u,ll l,ll r){
    Segment& node = tr[u];
    node.l = l;
    node.r = r;
    node.len = r - l + 1;
    node.toAdd = 0;
    if (l == r) {
        node.sum = a[l];
        node.sqsum = a[l] * a[l];
        return;
    }
    const ll half = (l + r) / 2;
    build(u * 2, l, half);
    build(u * 2 + 1, half + 1, r);
    node.sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
    node.sqsum = tr[u * 2].sqsum + tr[u * 2 + 1].sqsum;
}
/// Pushes node u's pending add A onto both children, then clears it.
/// For a child: sqsum += 2*A*sum + len*A*A uses the OLD sum, so sqsum is
/// updated before sum.
void pushdown(ll u){
    // Leaves have no children. Without this guard a tagged leaf writes into
    // tr[2u] / tr[2u+1], which are not tree nodes and can lie past the end of
    // tr[] when n is close to N.
    if(tr[u].l==tr[u].r) return;
    double A=tr[u].toAdd;
    if(!A) return;
    tr[u].toAdd=0;
    tr[u*2].sqsum+=2*A*tr[u*2].sum+tr[u*2].len*A*A;
    tr[u*2+1].sqsum+=2*A*tr[u*2+1].sum+tr[u*2+1].len*A*A;
    tr[u*2].sum+=tr[u*2].len*A;
    tr[u*2+1].sum+=tr[u*2+1].len*A;
    tr[u*2].toAdd+=A;
    tr[u*2+1].toAdd+=A;
}
/// Adds delta to every element of [l, r] within node u's range.
void add(ll u,ll l,ll r,double delta){
    pushdown(u);
    Segment& node = tr[u];
    if (r < node.l || node.r < l) return;  // no overlap
    if (l <= node.l && node.r <= r) {      // fully covered
        // sqsum first: (x+d)^2 expands using the sum before the add
        node.sqsum += 2 * delta * node.sum + delta * delta * node.len;
        node.sum += node.len * delta;
        node.toAdd += delta;
        return;
    }
    add(u * 2, l, r, delta);
    add(u * 2 + 1, l, r, delta);
    node.sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
    node.sqsum = tr[u * 2].sqsum + tr[u * 2 + 1].sqsum;
}
/// Returns the sum of the elements in [l, r] within node u's range.
double rsq1(ll u,ll l,ll r){
    pushdown(u);
    const Segment& node = tr[u];
    if (node.r < l || r < node.l) return 0;
    if (l <= node.l && node.r <= r) return node.sum;
    return rsq1(u * 2, l, r) + rsq1(u * 2 + 1, l, r);
}
/// Returns the sum of squares of the elements in [l, r] within node u's range.
double rsq2(ll u,ll l,ll r){
    pushdown(u);
    const Segment& node = tr[u];
    if (node.r < l || r < node.l) return 0;
    if (l <= node.l && node.r <= r) return node.sqsum;
    return rsq2(u * 2, l, r) + rsq2(u * 2 + 1, l, r);
}
/// Reads n, m and the array, then m operations:
///   "1 l r d" adds d to [l, r]; "2 l r" prints the mean of [l, r];
///   any other op "op l r" prints the variance of [l, r]. Answers use 4 decimals.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n>>m;
    for(ll i=1;i<=n;i++) cin>>a[i];
    build(1,1,n);
    cout<<fixed<<setprecision(4);  // sticky: set once for all answers
    for(ll i=1;i<=m;i++){
        ll op,l,r;
        cin>>op>>l>>r;
        if(op==1){
            double delta;
            cin>>delta;
            add(1,l,r,delta);
        }
        else if(op==2) cout<<rsq1(1,l,r)/(r-l+1)<<'\n';
        else{
            double ave=rsq1(1,l,r)/(r-l+1);
            double variance=(rsq2(1,l,r)/(r-l+1))-ave*ave;
            // E[x^2] - E[x]^2 can round to a tiny negative value; a variance is
            // never negative, and clamping keeps it from printing as "-0.0000".
            cout<<max(0.0,variance)<<'\n';
        }
    }
    return 0;
}

【例 7】区间加,区间求和,区间求最小值。(洛谷 P3130)

pushdown 在例 2 的基础上,把懒标记同时加到子节点的最小值上即可(区间整体加 k,最小值也加 k)。

#include<bits/stdc++.h>
using namespace std;
// Shared state for Example 7 (range add, range sum, range min).
using ll = long long;
constexpr ll N = 800009, INF = 2e9;
ll n, m, MOD;  // NOTE(review): MOD is not used by this program

// Node over [l, r]: range sum, range minimum and a pending add for the children.
struct Segment {
    ll l, r;
    ll len;    // r - l + 1
    ll sum;
    ll mn;
    ll toAdd;
};
Segment tr[N * 4];
ll x[N];  // input array, 1-indexed
/// Builds node u over x[l..r], storing the range sum and the range minimum.
void build(ll u,ll l,ll r){
    Segment& node = tr[u];
    node.l = l;
    node.r = r;
    node.len = r - l + 1;
    node.toAdd = 0;
    if (l == r) {
        node.sum = node.mn = x[l];
        return;
    }
    const ll half = (l + r) / 2;
    build(u * 2, l, half);
    build(u * 2 + 1, half + 1, r);
    node.sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
    node.mn = min(tr[u * 2].mn, tr[u * 2 + 1].mn);
}
/// Moves node u's pending add onto both children. An add of A raises a
/// child's sum by len*A and its minimum by A.
void pushdown(ll u){
    if (tr[u].l == tr[u].r) return;  // leaf: nothing below
    const ll pending = tr[u].toAdd;
    tr[u].toAdd = 0;
    for (const ll c : {u * 2, u * 2 + 1}) {
        tr[c].sum += tr[c].len * pending;
        tr[c].mn += pending;
        tr[c].toAdd += pending;
    }
}
/// Adds delta to every element of [l, r] within node u's range.
/// l, r and delta are taken by reference but are never modified.
void add(ll u,ll&l,ll&r,ll&delta){
    pushdown(u);
    Segment& node = tr[u];
    if (r < node.l || node.r < l) return;  // no overlap
    if (l <= node.l && node.r <= r) {      // fully covered: tag and stop
        node.toAdd += delta;
        node.sum += node.len * delta;
        node.mn += delta;
        return;
    }
    add(u * 2, l, r, delta);
    add(u * 2 + 1, l, r, delta);
    node.sum = tr[u * 2].sum + tr[u * 2 + 1].sum;
    node.mn = min(tr[u * 2].mn, tr[u * 2 + 1].mn);
}
/// Returns the sum of the elements in [l, r] within node u's range.
ll rsq(ll u,ll&l,ll&r){
    pushdown(u);
    const Segment& node = tr[u];
    if (node.r < l || r < node.l) return 0;
    return (l <= node.l && node.r <= r) ? node.sum
                                        : rsq(u * 2, l, r) + rsq(u * 2 + 1, l, r);
}
/// Returns the minimum of the elements in [l, r] within node u's range.
/// Parts outside the query contribute the largest long long, the identity of min.
/// The previous sentinel INF = 2e9 is not an identity here: the stored values
/// are long long and range adds can push them past 2e9, so min() would return
/// the sentinel instead of the real minimum.
ll rmq(ll u,ll&l,ll&r){
    pushdown(u);
    if(r<tr[u].l||tr[u].r<l) return numeric_limits<ll>::max();
    else if(l<=tr[u].l&&tr[u].r<=r) return tr[u].mn;
    return min(rmq(u*2,l,r),rmq(u*2+1,l,r));
}
int main(){
    scanf("%lld %lld",&n,&m);
    for(ll i=1;i<=n;i++) scanf("%lld",&x[i]);
    build(1,1,n);
    for(ll i=1;i<=m;i++){
        char t;
        ll x,y,z;
        cin>>t>>x>>y; 
        if(t=='M') cout<<rmq(1,x,y)<<endl;
        else if(t=='S') cout<<rsq(1,x,y)<<endl;
        else{
            cin>>z;
            add(1,x,y,z);
        }
    }
    return 0;
}