线段树-1

· · 个人记录

目录

总述

树状数组的本质是利用二进制进行区间分块。而对于更一般的区间问题,可以使用额外空间存储整个区间信息,高效率O(logn)修改、查询区间[l,r]的信息,线段树就显示出巨大优势,与树状数组相比更加通用。

接下来,利用线段树维护区间和和区间最大值为例,分析线段树创建、修改、查询。

开始正题

线段树的使用

(0)线段树相关函数定义及主函数中的调用

```cpp struct node{ int l,r; long long data; }t[4*N]; int main() { cin>>n>>m; for(int i=1;i<=n;i++)cin>>a[i]; build(1,1,n);//建树 while(m--) { int op;//任务 cin>>op; if(单点修改) { int x,y; add(1,x,y);//单点修改 } if(区间查询) { int x,y; cout<<query(1,x,y);//区间查询 } } 其他 } ``` ------------ ## (1)建树 给定一个长度为$N$的序列$a$,我们可以创建在区间$[1,N]$上创建一颗线段树,每个叶子节点$[i,i]$保存$a[i]$的值,线段树的二叉树结构很方便从下往上传递信息,即 $data[l,r]=max(data[l,mid],data[mid+1,r])$,也就是$t[p].data=max(t[2*p].data,t[2*p+1].data)$. 对于区间求和来说,即 $data[l,r]=data[l,mid]+data[mid+1,r]$,也就是 $t[p].data=t[2*p].data+t[2*p+1].data

建树相对简单,直接看代码

区间求和版

inline void build(int p,int l,int r)//建树 
{
    t[p].l=l,t[p].r=r;
    if(l==r)//到叶子了
    {
        t[p].date=a[l];
        return;
    } 
    int mid=l+r>>1;
    build(p<<1,l,mid);//建左树 p<<1=p*2
    build(p<<1|1,mid+1,r);//建右树 p<<1|1=p*2+1
    t[p].data=max(t[p<<1].data,t[p<<1|1].data)//从下往上回传信息 
}

区间最大值版

inline void build(int p,int l,int r)
{
    t[p].l=l,t[p].r=r;      //节点p表示区间[l,r]
    if(l==r)             //叶子
    {
        t[p].data=a[l];
        return;
    }
    int mid= l+r >>1;         //mid=(l+r)/2
    build(p<<1,l,mid);        //2*p = p<<1 
    build(p<<1|1,mid+1,r);     //2*p+1 = p<<1|1
    t[p].dat=max(t[p<<1].data,t[p<<1|1].data);  //从下往上传回信息
}
$Answer 1 O(n)

(2)单点修改

单点修改就是将a[x]加上一个数d.

我们需要从根节点出发,递归找到区间[x,x]也就是叶子节点,修改后,再从下往上更新其祖先节点。执行次数就是整个树的深度,因此时间复杂度为O(logn).

单点修改也比较简单没不做过多解释,直接上代码

区间求和版

inline void add(int p,int x,int d)//单点修改,给a[x]+d
{
    if(t[p].l==t[p].r)//到叶子了 
    {
        if(t[p].l==x)//加不加都行,重在要理解,不过加上好像速度更快了 
        {
            t[p].data+=d;
            return ;
        }
    }
    int mid=t[p].l+t[p].r>>1;//左右断点 
   if(x<=mid)add(p<<1,x,d);//x属于左半区间 
   else add(p<<1|1,x,d);//x属于右半区间  
   t[p].data=max(t[p<<1].data+t[p<<1|1].data);//从下往上更新信息 
}

区间最大值版

inline void add(int p,int x,int d)  // 单点修改 给a[x]加d 
{
    if(t[p].l==t[p].r)
    {
        t[p].data+=d;
        return; 
    }
    int mid=t[p].l+t[p].r >>1;
    if(x<=mid)add(p<<1,x,d);   //x属于左半区间 
    else add(p<<1|1,x,d);     //x属于右半区间 
    t[p].data=max(t[p<<1].data,t[p<<1|1].data); 
}

(3)区间查询

查询区间[l,r]的最大值,只需要从根节点的出发,递归执行以下过程:

注意开long long ,可能有些题不用,视情况而定

QWQ

上代码

区间求和版

inline long long query(int p,int L,int R)//区间查询 区分大小写L和R为题目中查询的区间边界 
{
    if(L<=t[p].l&&t[p].r<=R)return t[p].data;
    int mid=t[p].l+t[p].r>>1;
    int ret=0;
    if(L<=mid)ret=max(ret,query(p<<1,L,R));//左子节点有重叠 
    if(mid<R)ret=max(ret,query(p<<1|1,L,R));//右子节点有重叠 
    return ret;
}

区间最大值版

inline long long query(int p,int l,int r)
{
    if(l<=t[p].l&&t[p].r<=r)return t[p].data;
    int mid=t[p].l+t[p].r >>1;
    int ans=-(1<<30);
    if(l<=mid)ans=max(ans,query(p<<1,l,r));
    if(mid<r)ans=max(ans,query(p<<1|1,l,r));
    return ans;
}
$Answer 2$ 实际该过程会把询问区间$[l,r]$在线段树上分成$O(logN)$个节点,取其最大值作为答案,因此复杂度为$O(logn)$。(分情况讨论就可以证明)证明如下(忽略红线): 对于节点$p$,负责的区间表示为$[p_l,p_r]$,设$mid=\lfloor(p_l+p_r)/2\rfloor

(4)区间修改 --lazy延迟标记

需要对一个区间[l,r]进行修改时,如果暴力将区间修改做单点修改,区间修改一次复杂度就为O(nlogn),m次修改复杂度无法接受(比单纯数组修改复杂度还大)。

实际上,如果修改区间[l,r],完全覆盖了节点p所代表的区间[p_l,p_r],而在查询的时候查询区间也完全覆盖了节点p所代表的区间,我们没必要修改节点p以下的所有子树节点,这会做很多无用功。

也就是,在执行修改指令时,同样可以在 l \le p_l \le p_r \le r情况下立即返回,只不过在回溯之前在节点p增加一个标记,标识“该节点曾经被修改,但其子节点尚未被更新。”

但是在查询时,如果需要从节点p向下递归,我们再检查p是否具有标记。如果有标记,根据标记信息更新p的两个子节点,同时为p的两个子节点增加标记,然后清除p的标记。

也就是说,除了在修改指令中直接划分成O(logN)个节点之外,对于任意节点修改都延迟到“在后续操作中需要递归其子孙节点时,再修改”,这样每条查询或修改(区间修改)指令时间复杂度都降低到了O(logN),这个标记就是延迟标记。

注意结构体的定义需要改变qwq!!!

struct node{
    int l,r;
    long long lazy,data;
}t[4*N];
inline void pushup(int p)
{
    t[p].data=t[p<<1].data+t[p<<1|1].data;
}
inline void pushdown(int p)
{
    if(t[p].lazy)
    {
        t[p<<1].data+=t[p].lazy*(t[p<<1].r-t[p<<1].l+1)//更新左子树
        t[p<<1|1].data+=t[p].lazy*(t[p<<1|1].r-t[p<<1|1].l+1)//更新右子树
        t[p<<1].lazy+=t[p].lazy;//下传标记
        t[p<<1|1].lazy+=t[p].lazy;
        t[p].lazy=0;//注意清空本节点 
        return ; 
    }
}
inline void add(int p,int L,int R,int d)//区间修改
{
    if(l<=t[p].L&&t[p].r<=R)//完全包含
    {
        t[p].data+=d*(t[p].r-t[p].l+1);
        t[p].lazy+=d;
        return;
    } 
    pushdown(p);
    int mid=t[p].l+t[p].r>>1;
    if(l<=mid)add(p<<1,l,r,d);
    if(mid<r)add(p<<1|1,l,r,d);
    pushup(p);//从左右子树更新根节点信息 
}
inline long long query(int p,int L,int R)//区间查询
{
    if(l<=t[p].L&&t[p].r<=R)return t[p].data;//完全包含,直接返回 
    pushdown(p);
    int mid=t[p].l+t[p].r>>1;
    long long ret=0;
    if(l<=mid)ret+=query(p<<1,l,r);
    if(mid<r)ret+=query(p<<1|1,l,r);
    return ret;
} 

线段树的应用(例题)

例题一

题意

GSS3 - Can you answer these queries III

给定义长度为N的序列A,m条指令,"1 x y"表示查询区间[x,y]中最大连续子段和,"2 x y"表示将A[x]改成y

分析

每个节点除了维护区间端点外,再维护4个信息:区间和data,区间最大连续子段和sum,紧靠左端的最大连续子段和lmax,紧靠右端的最大连续子段和rmax。 从下往上更新这4个信息:

t[p].data=t[p<<1].data+t[p<<1|1].data;
t[p].maxl=max(t[p<<1].maxl,t[p<<1].data+t[p<<1|1].maxl);
t[p].maxr=max(t[p<<1|1].maxr,t[p<<1|1].data+t[p<<1].maxr);
t[p].sum=max(max(t[p<<1].sum,t[p<<1|1].sum),t[p<<1].maxr+t[p<<1|1].maxl);

通常,我们将从下往上更新的内容放在pushup函数中,因此定义pushup函数

最终实现

点击参考代码或直接看下方

#include<bits/stdc++.h>
using namespace std;
const int N=5e4+10;
struct node{
    int l,r;
    long long data,maxl,maxr,sum;
}t[4*N];
int n,m;
int a[N];
inline void pushup(int p)
{
    t[p].data=t[p<<1].data+t[p<<1|1].data;
    t[p].maxl=max(t[p<<1].maxl,t[p<<1].data+t[p<<1|1].maxl);
    t[p].maxr=max(t[p<<1|1].maxr,t[p<<1|1].data+t[p<<1].maxr);
    t[p].sum=max(max(t[p<<1].sum,t[p<<1|1].sum),t[p<<1].maxr+t[p<<1|1].maxl);
}
inline void build(int p,int l,int r)//建树 
{
    t[p].l=l,t[p].r=r;
    if(l==r)
    {
        t[p].sum=t[p].data=t[p].maxl=t[p].maxr=a[l];
        return ;
    }
    int mid=l+r>>1;
    build(p<<1,l,mid);
    build(p<<1|1,mid+1,r);
    pushup(p);
}
inline void update(int p,int x,int d)//区间修改 
{
    if(t[p].l==t[p].r)
    {
        t[p].sum=t[p].data=t[p].maxl=t[p].maxr=d;
        return;
    }
    int mid=t[p].l+t[p].r>>1;
    if(x<=mid)update(p<<1,x,d);
    else update(p<<1|1,x,d);
    pushup(p);
}
inline node query(int p,int L,int R)//区间查询 
{
    if(L<=t[p].l&&t[p].r<=R)return t[p];
    int mid=t[p].l+t[p].r>>1;
    int ret=-(1<<30);
    if(L>mid)return query(p<<1|1,L,R);
    else if(R<=mid)return query(p<<1,L,R);
    else 
    {
        node a,b;
        a=query(p<<1,L,R);
        b=query(p<<1|1,L,R);
        node tt;
        tt.data=a.data+b.data;
        tt.maxl=max(a.maxl,a.data+b.maxl);
        tt.maxr=max(b.maxr,b.data+a.maxr);
        tt.sum=max(max(a.sum,b.sum),a.maxr+b.maxl);
        return tt;
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)scanf("%d",a+i);
    build(1,1,n);
    scanf("%d",&m);
    while(m--)
    {
        int op,x,y;
        scanf("%d%d%d",&op,&x,&y);
        if(op==0)
            update(1,x,y);
        if(op==1)
            printf("%d\n",query(1,x,y).sum);
    }
    return 0;
}

学会了吗?

参考文献 李煜东《算法竞赛进阶指南》&& 技术支持博客

The end