线段树入门

· · 算法·理论

前言

本文主要帮助像我一样的蒟蒻更好的入门线段树,所以内容较基础,大佬们看到这里已经可以直接划走了。

学习基础

简介

线段树,顾名思义,就是一棵处理线段的树形结构。用于高效处理区间问题,只要是具有可合并性 的区间问题都可以用它解决,常见的如区间加、区间乘、区间修改、区间最值等等,学习线段树是你从蒟蒻晋升大佬的主要台阶,那么,就让我们正式开始学习。

例题一

那么让我们来看一道题:

已知一个数列,你需要进行下面两种操作:

大家的第一反应可能是暴力,每次询问就从l枚举到r计算所有的值,但每次询问O(n)的复杂度对于大规模数据来说有些太慢了。

聪明一点的人可能会想到前缀和,他的询问是O(1)的没错,但多次操作中它具有O(n)的修改复杂度,所以本质上复杂度和上一种方法没有区别。

那么这个时候我们就需要本篇的主角——树状数组线段树了。

思想

线段树是一棵二叉树,树中的每一个结点表示了一个区间[l,r],若l=r,说明这个点是个叶子节点,否则对于节点tr[u],他的左右节点分别是tr[u \times 2]tr[u \times 2 + 1]

由于每一层线段树每个节点所维护的区间长度都是父亲节点的一半,所以线段树深度为 log_2(n),那么他就会有2^{log_2(n)+1}-1个节点,当然这些都不重要,你只用记住线段树数组大小要开到4n就行了(n为区间大小)。

例题一实现

基础结构定义

struct node{
    int l, r; //左右儿子
    int val;//当前存的值
    //int lazy; <- 先不用管,等会有用
}tr[N << 2];//四倍空间

等会我们要用到的函数:

void pushup(int x);//将数据向上传递到x
void build(int u, int l, int r);//初始化[l,r]的值
void change(int u, int x,int k);//将a[x]修改为k
int query(int u, int l, int r);//查询[l,r]的值

向上传递数据

void pushup(int x){
    tr[x].val = tr[x * 2].val + tr[x * 2 + 1].val;//根据儿子更新父亲
}

建树我们用递归实现

void build(int u, int l, int r){
    tr[u] = {l, r, 0};//初始化
    if(l == r){//判断是否是叶子节点
        tr[u].val = a[l];//是叶子节点直接赋值
        return;
    }
    int mid = (l + r) / 2;
    build(u * 2, l, mid), //否则递归左右儿子
    build(u * 2 + 1, mid + 1, r);
    pushup(u);//将数据上传到u
}

单点修改

void change(int u, int x, int k){
    if(tr[u].l == tr[u].r){//为叶子节点
        tr[u].val = k;//直接修改
        return;
    }
    int mid = (tr[u].l + tr[u].r) / 2;
    if(x <= mid) change(u * 2, x, k);//否则递归修改左右节点(有点像建树)
    else change(u * 2 + 1, x, k);
    pushup(u);//数据上传
}

区间查询

int query(int x, int l, int r){
    if(tr[x].l >= l && tr[x].r <= r) return tr[x].val; //若区间完全被查询区间包含,返回值
    int mid = (tr[x].l + tr[x].r) / 2, sum = 0;
    if(l <= mid) sum += query(x * 2, l, r);//否则分别查询左右儿子
    if(r > mid) sum += query(x * 2 + 1, l, r);
    return sum;//返回结果
}

但是,线段树的强大之处还不止这些,不然我还不如去写树状数组

例题二

让我们看看这题: 已知一个数列,你需要进行下面两种操作:

那么学完线段树的你很容易就能想出区间赋值的代码:

void update(int l, int r, int k){
    for(int i = l; i <= r; i++)
        change(1, i, k);
}

然后发现TLE了

因为这种方法的复杂度达到了O(nlogn),那么怎么优化呢?
我们注意到在查询之前,节点的值都是没有用的,那么我们可以在节点赋值后,暂时先不赋值,先“拖延”一会,等上面催了要查询的时候再赋值。是不是和你一模一样

这就是懒标记优化,我们在赋值时仅仅更新懒标记,而不真正的更新节点,这样我们的复杂度就达到了O(logn)

例题二实现

定义

struct node{
    int l, r, 
        val, lazy;//lazy表示这个区间每个点要累加的值
}tr[N << 2];

向下更新

void pushdown(int x){//从x向下更新儿子节点
    if(tr[x].lazy){//如果有懒标记(其实这个判断删了也行)
        tr[2 * x].lazy += tr[x].lazy, //左右儿子继承懒标记
        tr[2 * x + 1].lazy += tr[x].lazy,
        tr[2 * x].val += tr[x].lazy * (tr[2 * x].r - tr[2 * x].l + 1),//左右儿子实际值更新成懒标记*这个区间的节点数量
        tr[2 * x + 1].val += tr[x].lazy * (tr[2 * x + 1].r - tr[2 * x + 1].l + 1);
        tr[x].lazy = 0;//别忘了清空懒标记!!!
    }
}

那么我们的区间赋值就是这样实现的

void update(int u, int l, int r, int k){//将[l,r]依次加上k
    if(l <= tr[u].l &&  tr[u].r <= r)//在区间内
        tr[u].val += k * (tr[u].r - tr[u].l + 1),
        tr[u].lazy += k;//仅赋值当前点,然后更新懒标记
    else{
        pushdown(u);//因为即将遍历到子节点,必须保证子节点的值是最新的
        int mid = (tr[u].l + tr[u].r) / 2;
        if(l <= mid) update(u * 2, l, r, k);//遍历子节点
        if(r > mid) update(u * 2 + 1, l, r, k);
        pushup(u);//子节点更新后重新计算父节点
    }
}

然后区间查询也要改一下

int query(int x, int l, int r){
    if(tr[x].l >= l && tr[x].r <= r) return tr[x].val;
    pushdown(x);//仅加了这一行,即将遍历到子节点赶紧更新
    int mid = (tr[x].l + tr[x].r) / 2, sum = 0;
    if(l <= mid) sum += query(x * 2, l, r);
    if(r > mid) sum += query(x * 2 + 1, l, r);
    return sum;
}

然后就搞定了,我们再来看一道题:

例题三

已知一个数列,你需要进行下面两种操作:

我们注意到max(a,b,c,d)=max(max(a,b),max(c,d)),也就是说,一个区间的最大值可以通过他的儿子区间得出。

所以其实这题只要改几个符号就好了

void pushup(int x){
    tr[x].val = max(tr[x * 2].val, tr[x * 2 + 1].val);//改成求最大值
}

int query(int x, int l, int r){
    if(tr[x].l >= l && tr[x].r <= r) return tr[x].val;
    pushdown(x);
    int mid = (tr[x].l + tr[x].r) / 2, sum = -1;
    if(l <= mid) sum = max(sum, query(x * 2, l, r));
    if(r > mid) sum = max(sum, query(x * 2 + 1, l, r));
    return sum;
}//同理

习题

下面给几道课后习题:

P3374
P3368
P1253
P3372
P3373
P1198