线段树入门

· · 个人记录

我们一般维护序列区间和是这么做的:

int query(int l, int r){
    int res = 0;
    for(int i = l; i <= r; ++ i) res += a[i];
    return res;
}
void update(int l, int r, int k){
    for(int i = l; i <= r; ++ i) a[i] = k;
    return;
}

然而复杂度实在是太高了。

现在考虑用如下的方法(只考虑查询)。

假设序列长度为 n, 我们把前 \lfloor \dfrac{n}{2} \rfloor 个元素分为一组,另外的分为另一组。

如下所示:

\{a_1,a_2,a_3,a_4,……a_{\lfloor\frac{n}{2}\rfloor}\},\{a_{\lfloor\frac{n}{2}\rfloor + 1},a_{\lfloor\frac{n}{2}\rfloor + 2},……,a_n\}

将这两个组的存为 sum_1sum_2

那么对于形如 query(1,x)(x>\lfloor\frac{n}{2}\rfloor)query(x,n)(x<=\lfloor\frac{n}{2}\rfloor) 的询问,我们可以将他们变成如下的代码:

int query(int l, int r){
    int res = 0;
    if(l == 1 && r > mid/*mid即指floor(mid)*/)
        res = sum[1],
        l = mid + 1;
    else if(r == n && l <= mid)
        res = sum[2],
        r = mid - 1;
    for(int i = l; i <= r; ++ i) res += a[i];
    return res;
}

可以发现,这样子我们优化了一点点的常数。

接下来,我们考虑这么分组:

对于分出来的每一组,都采用如上的方法分成两组,直到不能再分为止。

n4,则如图一所示:

可以看到,这里一共分为了 7 组,分别存为sum_1,sum_2,sum_3……sum_7

可以发现如下的关系:

sum_1=sum_2+sum_3, sum_2=sum_4+sum_5, sum_3=sum_6+sum_7, ……

归纳一下,可以发现 sum_p=sum_{p\times2}+sum_{p \times2+1} (相应的,对于 n8 的区间,也存在这个关系)

我们可以把这个结构看成一棵二叉树:

那么对于每一个节点 p,左儿子的编号就为 p\times2,右儿子的编号就是 p\times2+1,而每个节点的 sum 就为左右儿子的和。值得一提的是,对于叶子节点的 sum 就是在 a 数组中对应的值。

特殊的,对于 n 为其他值的情况,这棵树的样子就像这样(图二):

我们称这样的一棵树为线段树

我们来看一下线段树节点的数量:

如图,其中右边的数字为线段树中每一层节点的数量。那么节点数量的公式就是:

1+2+4+……2^{k}(k=\log n\Leftrightarrow2^{k}\approx n)

S=1+2+4+……2^{k},则 2S=2+4+8+……2^{k+1}

S=1+2+4+……2^{k}\space\space\space\space\space\space(1)\\ 2S=2+4+8+……2^{k+1}\space(2) \end{cases} 然而就像图二一样,期望上树的高度是 $\log n$ 的,但是实际上我们会发现总会有一些节点的深度会比 $\log n$ 大 $1$。所以在开 $sum$ 数组时,我们一般开四倍于 $n$ 的空间。 建树的代码如下: ```cpp int sum[N << 2];//四倍 void push_up(int p){ sum[p] = sum[p << 1] + sum[p << 1 | 1]; return; } void build(int p, int l, int r){ //len[p] = r - l + 1 后面要用。 if(l == r){ sum[p] = a[l]; return; } int mid = l + r >> 1; build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r); push_up(p); return; } ``` 其中 `>>、<<、|` 是位运算。 - `<<` 叫做左移,$x$ 左移 $y$ 位的意思是 $x\times2^{y}$; - `>>` 叫做右移,$x$ 右移 $y$ 位的意思是 $\lfloor\dfrac{x}{2^{y}}\rfloor$; - `|` 叫做(按位)或,基本位运算; 在这里 `p << 1` 即 `p * 2`,`p << 1 | 1` 即 `p * 2 + 1`。 那么对于每次询问,我们都可以通过某种方式,将任意的区间拆分为不超过 $O(\log n)$ 个的线段树上节点。 ~~伪~~证明如下: 设询问的元素为 $a_l,a_{l+1},a_{l+2},……a_{r-1},a_r$。 观察线段树的结构,发现节点层数每上升,下一层**两两相邻**的节点都会合并成一个大节点。以叶子节点为例,分以下四种情况(设 $n$ 为 $2^k(k>0)$,即 $n$ 为 $2$ 的正整数次幂): - 如果 $l$ 为奇数,$r$ 为偶数,那么对应到上一层的结点的序列就为: $$\{a_l,a_{l+1}\},\{a_{l+2},a_{l+3}\},\{a_{l+4},a_{l+5}\}……\{a_{r-1},a_{r}\}$$ 可以发现转移到上一层进行处理的节点数为 $O(\lfloor\frac{r-l+1}{2}\rfloor) \{a_l,a_{l+1}\},\{a_{l+2},a_{l+3}\},\{a_{l+4},a_{l+5}\}……\{a_{r-2},a_{r-1}\},a_r

这种情况下,对于 a_r 的修改可以忽略不计,节点数也为 O(\lfloor\frac{r-l+1}{2}\rfloor)

a_l,\{a_{l+1},a_{l+2}\},\{a_{l+3},a_{l+4}\},\{a_{l+5},a_{l+6}\}……\{a_{r-2},a_{r-1}\},a_r

同理,忽略 a_la_r,节点数是 O(\lfloor\frac{r-l+1}{2}\rfloor)

a_l,\{a_{l+1},a_{l+2}\},\{a_{l+3},a_{l+4}\},\{a_{l+5},a_{l+6}\}……\{a_{r-1},a_{r}\}

还是一样,上面那层的节点数为 O(\lfloor\frac{r-l+1}{2}\rfloor)

总的来说,每往上一层,需要访问的节点数会缩小一半,而每层对复杂度有贡献的节点的复杂度是 O(1)。所以分裂出的区间总是 O(\log (r-l+1)) 的。而询问的区间长度最长可以取到 n,所以复杂度的上界就是 O(\log n)

也就是说,对于每次询问,我们只要求出线段树上的 O(\log n) 个节点的和,而不是原来的 O(n) 个元素,复杂度大大优化。

以该图为例:

假如我们查询 [2,4] 区间的话,我们只要求出 sum_5+sum_3 的值,而非暴力版本的 a_2+a_3+a_4。当然现在我们的例子规模比较小,放到更大的数据里优化会更明显。

现在我们来讲解一下查询的代码:

int query(int p, int l, int r, int x, int y/*x和y代表从查询区间*/){
    if(y < l || r < x) return 0;
    if(x <= l && r <= y) return sum[p];
    int mid = l + r >> 1;
    return query(p << 1, l, mid, x, y) + query(p << 1 | 1, mid + 1, r, x, y);
}
//输出示范
cout << query(1, 1, n, l, r);

作者所采用的线段树写法中每个函数需要对 l,r 进行一次计算,从树根的 [1,n] 开始,左儿子的区间就为 [l,mid],右儿子的区间就为 [mid+1,r]。从树根开始搜索,分别遍历左右儿子,分以下三种情况:

在这种情况下,单点修改的代码如下:

void update(int p, int l, int r, int x, int k){//把x位置上的数改为k
    if(l == r){//已经是叶子节点
        sum[p] = k;
        return;
    }
    int mid = l + r >> 1;
    if(x <= mid) update(p << 1, l, mid, x, k);
    else update(p << 1 | 1, mid + 1, r, x, k);
    push_up(p);
    return; 
}
update(1, 1, n, x, k);//修改

从根节点开始遍历,分以下三种情况:

别忘记 push_up 合并修改过的节点。

很容易可以看出,时间复杂度就是线段树的深度即 O(\log n)。至此,我们解决了单点修改,区间查询的问题。

现在我们来对区间修改进行优化。可以看到,这样的线段树的区间修改的复杂度仍然很高。参照我们进行查询的代码,有什么方法可以使我们修改的时候只修改分裂出的 O(\log n) 个区间,而不是所有叶子节点一并修改呢?

可以看到,现在左右两边黑色线条中间的这棵子树即为线段树的一部分。假设现在我们需要将这棵子树的所有叶子结点的权值修改为 k,那么需要遍历的节点(即对时间复杂度有贡献的节点)就为图中红色的根节点和绿色的节点。现在我们采用一种方法:

若需要对节点 p 及它下面所属的所有节点进行修改,则只对根节点 p 进行一次修改标记

什么意思呢?假设现在我们线段树上有一个编号为 p 的节点,所管辖的区间为 [l,r]。现在我们需要对 p 的子树进行赋值为 k 的操作。很容易可以看出,sum_p 的值将变为 (r-l+1)\times k (即该区间长度为 (r-l+1),每个位置都为 k),这时我们就停止了往下遍历,而是在另一个标记数组 tag 上对 tag_p 赋值为 k 并直接返回。

代码:

void f(int p, int k){
    sum[p] = len[p] * k;//len指的是节点p所管辖的区间的长度(即r-l+1)
    tag[p] = k;
    return;
}

那么我们如何处理剩下来的那些没有被修改过的节点呢?

观察 query 函数,我们发现,但凡一个节点被访问到了,它到根节点的路径上的节点一定也被访问到过。于是只要有一个节点被访问了,我们就把这个标记下传到它的左右儿子去。

以图一所示的线段树为例,如果我们需要对 [1,2] 进行区间赋值为 k 的操作,我们只对 sum_2 赋值为 (r-l+1)\times k 并且把 tag_2 赋值成 k。我们发现这样一来对父亲方向上的节点的 push_up 就相当于暴力赋值的效果,而并没有真正赋值到它的子树 45。而当我们需要进行 query(1, 1, n, 1, 1) (即访问 2 的左儿子 4) 的操作时,我们会发现这个函数一定会访问到 2。所以在访问到 2 时,我们就可以把它的左右儿子也进行一次 f 的操作。

代码如下:

void push_down(int p){
    if(tag[p] == -1){//这里假设初始值为-1,即没有修改的操作
        f(p << 1,tag[p]);//下传给左右儿子
        f(p << 1 | 1,tag[p]);
        tag[p] = -1;//注意清零,防止重复下传。
        //如果是加法标记,标记的初始值就是0
        //如果是乘法标记,初始值是1
    }
    return;
}

所以我们的区间赋值的代码就长这样:

void update(int p, int l, int r, int x, int y, int k){//把区间 [l,r] 赋值成 k
    if(y < l || r < x) return;
    if(x <= l && r <= x){
        f(p, k);
        return;
    }
    int mid = l + r >> 1;
    push_down(p);//下传当前节点的标记
    update(p << 1, l, mid, x, y, k);
    update(p << 1 | 1, mid + 1, r, x, y, k);
    push_up(p);
    return;
}

对应的,query 函数所对应的地方也要进行 push_down 操作。

类似的,update 的时间复杂度分析也和 query 一样,为 O(\log n)

以上就是针对区间赋值,区间查询操作的线段树。

现在看一道例题。这里需要我们支持区间加,区间查询

观察我们上面的线段树,发现唯一不一样的地方就是修改部分。这里要求我们进行的修改操作是加而不是赋值。那么我们只需要对修改操作中的 f 操作进行调整即可。

重新定义 f 操作:

p 的子树统一加上 k

发现接下来的操作其实就是对标记数组进行调整。把标记数组称为 add。如果 add_p 的值为 k,就代表需要将 p 的左右子树加上 k。于是 f 操作就变成了这样:

void f(int p, int k){//把p的子树加上k 
    sum[p] += /*注意是+=*/ len[p] * k;
    add[p] += k;
    return; 
}

其他部分没变。

再来看线段树2。这题比较复杂,需要我们在之前的基础上进行区间乘操作。对应的,我们也另外设立一个针对乘法的标记数组 mul。现在重新定义一下 f 函数:

f(p,k,op)即对 p 的子树加/乘上 k。其中 op 只为 010 表示为加操作,1 表示为乘操作。

先来看加法:

void f(int p, int k, int op){
    if(op == 0){//加法 
        sum[p] += len[p] * k;
        add[p] += k; 
    }else if(op == 1){
        //... 
    }
    return; 
}

基本没变。

接下来的关键在于如何维护乘法的标记。如果对于一棵子树乘以 k,我们只是对 sum_p 进行乘以 k 的操作吗?并不是这样。我们知道,一个节点上的值实际上是这样的:

设该节点为 p,父亲为 f\lfloor\dfrac{p}{2}\rfloor,则该节点的权值 S 应该表示为

S=sum_p\times mul_f+add_f

现在对式子两边都乘以 k,得:

S\times k=(sum_p\times mul_f+add_f)\times k=sum_p\times k\times mul_f+add_f\times k

意思就是说,当把子树 f 乘以 k 时我们不仅仅要将 sum_fmul_f 乘以 k,还要把加法标记也一并乘以 k

代码如下:

void f(int p, int k, int op){
    if(op == 0){//加法 
        sum[p] += len[p] * k;
        add[p] += k; 
    }else if(op == 1){//乘法 
        sum[p] *= k;
        add[p] *= k;
        mul[p] *= k; 
    }
    return; 
}

问题来到了 push_down 函数。现在我们有两个标记,那么,该先下传哪一个呢?

还是这个式子(当前节点编号为 p):

针对左儿子有:

a_{p<<1}=sum_{p<<1}\times mul_p+add_p

显然,根据运算法则,我们先运算乘法。理所当然的,先下传乘法标记:

void push_down(int p){
    if(mul[p] != 1){
        f(p << 1, mul[p], 1);
        f(p << 1 | 1, mul[p], 1);
        mul[p] = 1;//清空
    }
    if(add[p]){
        f(p << 1, add[p], 0);
        f(p << 1 | 1, add[p], 0);
        add[p] = 0;
    }
    return;
}

区间和取模部分请读者自行添加。