zkw 线段树学习笔记

· · 个人记录

介于某些高素质人群使用 zkw 写标程的行为令人不齿,被迫学习。。。

以 P3372 为例,我最快的递归线段树用了 319ms,内存开了 11.49MB。然后 zkw 线段树在时间和空间上都是递归版的 \dfrac{1}{2} 左右。

虽然可能在常数方面写的很烂,但确实非递归版要快一些。

总结一下 zkw 线段树,大概就是:

性质:堆存储,下标加上 M 可直接对应叶节点。

非递归实现方式:自底向上。

区间的修改查询等操作:标记永久化。

一些细节:使用开区间,类似树状数组。

然后记住一个就是 s 若左则其右兄弟是区间内的,t 若右则其左兄弟是区间内的。

其实就是一用来卡常的线段树。。。

参考文献:

  1. 线段树的扩展之浅谈zkw线段树

  2. 统计的力量(ppt)

代码:

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=1e5;

ll n,m,op,x,y,k,M;

struct sgt{
    ll val,laz;
    #define val(x) tree[x].val
    #define laz(x) tree[x].laz
}tree[N*4+5];

void add(ll l,ll r,ll k) {
    ll s=M+l-1,t=M+r+1,nnum=1,lnum=0,rnum=0;
    for(;s^t^1;s>>=1,t>>=1,nnum<<=1) {
        val(s)+=k*lnum;val(t)+=k*rnum;
        if(~s&1) {laz(s^1)+=k;val(s^1)+=k*nnum;lnum+=nnum;}
        if(t&1) {laz(t^1)+=k;val(t^1)+=k*nnum;rnum+=nnum;}
    }
    for(;s;s>>=1,t>>=1) {
        val(s)+=k*lnum;
        val(t)+=k*rnum;
    }
}

ll query(ll l,ll r) {
    ll s=M+l-1,t=M+r+1,nnum=1,lnum=0,rnum=0,ans=0;
    for(;s^t^1;s>>=1,t>>=1,nnum<<=1) {
        if(laz(s)) ans+=laz(s)*lnum;
        if(laz(t)) ans+=laz(t)*rnum;
        if(~s&1) {ans+=val(s^1);lnum+=nnum;}
        if(t&1) {ans+=val(t^1);rnum+=nnum;}
    }
    for(;s;s>>=1,t>>=1) {
        ans+=laz(s)*lnum;
        ans+=laz(t)*rnum;
    }
    return ans;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

void writeln(ll x) {
    write(x);putchar('\n');
}

void build() {
    M=1;
    for(;M<=n+1;M<<=1);
    for(ll i=1;i<=n;i++) {val(i+M)=read();}
    for(ll i=M-1;i>=1;i--) {val(i)=val(i<<1)+val(i<<1|1);}
}

int main() {

    n=read();m=read();

    build();

    while(m--) {
        op=read();x=read();y=read();
        if(op==1) {
            k=read();add(x,y,k);
        }
        if(op==2) {
            writeln(query(x,y));
        }
    }

    return 0;
}

代码(轻压行版本):

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;

const ll N=1e5;

ll n,m,op,x,y,k,M;

struct sgt{
    ll val,laz;
    #define val(x) tree[x].val
    #define laz(x) tree[x].laz
}tree[N*4+5];

void add(ll l,ll r,ll k) {
    ll s=M+l-1,t=M+r+1,nnum=1,lnum=0,rnum=0;
    for(;s^t^1;s>>=1,t>>=1,nnum<<=1) {
        val(s)+=k*lnum;val(t)+=k*rnum;
        if(~s&1) {laz(s^1)+=k;val(s^1)+=k*nnum;lnum+=nnum;}
        if(t&1) {laz(t^1)+=k;val(t^1)+=k*nnum;rnum+=nnum;}
    }
    for(;s;s>>=1,t>>=1) {val(s)+=k*lnum;val(t)+=k*rnum;}
}

ll query(ll l,ll r) {
    ll s=M+l-1,t=M+r+1,nnum=1,lnum=0,rnum=0,ans=0;
    for(;s^t^1;s>>=1,t>>=1,nnum<<=1) {
        if(laz(s)) ans+=laz(s)*lnum;if(laz(t)) ans+=laz(t)*rnum;
        if(~s&1) {ans+=val(s^1);lnum+=nnum;}
        if(t&1) {ans+=val(t^1);rnum+=nnum;}
    }
    for(;s;s>>=1,t>>=1) {ans+=laz(s)*lnum;ans+=laz(t)*rnum;}
    return ans;
}

inline ll read() {
    ll ret=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
    while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
    return ret*f;
}

void write(ll x) {
    static char buf[22];static ll len=-1;
    if(x>=0) {
        do{buf[++len]=x%10+48;x/=10;}while(x);
    }
    else {
        putchar('-');
        do{buf[++len]=-(x%10)+48;x/=10;}while(x);
    }
    while(len>=0) putchar(buf[len--]);
}

void writeln(ll x) {
    write(x);putchar('\n');
}

void build() {
    M=1;
    for(;M<=n+1;M<<=1);
    for(ll i=1;i<=n;i++) val(i+M)=read();
    for(ll i=M-1;i>=1;i--) val(i)=val(i<<1)+val(i<<1|1);
}

int main() {

    n=read();m=read();

    build();

    while(m--) {
        op=read();x=read();y=read();
        if(op==1) {k=read();add(x,y,k);}
        if(op==2) {writeln(query(x,y));}
    }

    return 0;
}