【线段树】区间修改(区间加、区间乘)标记下放顺序及具体操作

· · 个人记录

具体题目请参考洛谷P3373 线段树 2

这道题的数据规模挺大的,不开long long只能拿到30分,写的时候记得注意一下,然后我们可以开始讲线段树的部分了。

线段树同时对区间有加操作和乘操作时,它们的lazy标记下放顺序是经常被讨论的问题,当然,严格来说二者孰先孰后都是不会对答案造成影响的,无非是运算表达式的写法不同,但是在计算机中,精度将是一个不得不被考虑的问题。所以就有了这个老生常谈的事情,下面我们来细致讨论一下它们哪个先做会更优。

我们知道,lazy的原理就是将线段树上的区间操作延后进行,在遍历到有标记的区间时,这一步被标记的操作才会真正的向下进行,否则就会被存在标记中无限地被延后。

具体地,当我们需要为线段树上的节点打lazy标记时,我们需要为线段树的每一个节点都添加两个标记,分别储存它们的区间加和区间乘操作信息。然后在修改时,如果遍历到一个区间[l,r]满足l >= tgtl && r <= tgtr直接修改这一区间的值(tgtl和tgtr是目标区间的左右端点),并在其上打上标记,紧接着返回,返回时修改遍历路径上的信息

更加具体地,做加法修改时要将操作数乘要修改的区间的长度,因为线段树的每个节点存的是区间内所有元素的和,即\sum_{i=l}^ra[i],那么修改后序列上的每个a[i] (i∈[l,r])都会加上操作参数k,即\sum_{i=l}^r(a[i]+k),由乘法分配律可得每个节点的权值可直接更改为(r-l+1)\times k+\sum_{i=l}^ra[i];再来看区间乘法,将区间上的所有数字都乘k,即令区间[l,r]的权变为\sum_{i=l}^ra[i]\times k,由乘法分配律可知每个节点的权可直接更改为k\sum_{i=l}^ra[i]

写成代码块的形式就是这样:

if (l >= tgtl && r <= tgtr) {
   if (t == 1) {    //t==1为区间乘,t==2为区间加
      tr[id].v = mod(tr[id].v * k);//int mod(int x)为手写取模函数
      tr[id].tm = mod(tr[id].tm * k);//打上乘法标记
   }
   else {
      tr[id].v = mod(tr[id].v + mod(k * mod(r - l + 1)));//修改权值时注意乘区间长度
      tr[id].lz = mod(k + tr[id].lz);//打上加法标记
   }
   return;
}

这是具体的修改过程,下面我们来看看标记下放。

先说原理,我们已经在修改函数的边界处对两种操作打下了相应的lazy标记然后返回,所以当我们遍历到一个节点有lazy标记时,意味着该节点的子树仍有未被执行的操作被延后了,所以我们在遍历到某个节点时,就将这个节点的lazy标记下放到它的儿子节点上,以此来为后面的操作做准备。下放时应该将标记同时下放到它的左右儿子上,由于线段树是二分产生的,所以线段树上所有点的度数非零即二,不需要为任何节点考虑它是存在一个还是两个子节点。

更加具体地,我们来探讨一下是先下放加标记还是先下放乘标记。

综上所述,我们优先下放乘法标记。

那么下放标记的函数呢,写成代码块的形式就是这样:

inline void putdown(int id) {
    node ls = tr[id << 1];      //为了方便写,先用局部变量抽出来
    node rs = tr[id << 1 | 1];
    ls.v = mod(mod(ls.v * tr[id].tm) + mod(tr[id].lz * mod(ls.r - ls.l + 1)));//先下放乘法标记,然后再加法,中间需要的地方取模
    rs.v = mod(mod(rs.v * tr[id].tm) + mod(tr[id].lz * mod(rs.r - rs.l + 1)));
    ls.lz = mod(mod(ls.lz * tr[id].tm) + tr[id].lz);//标记单独下放
    rs.lz = mod(mod(rs.lz * tr[id].tm) + tr[id].lz);
    ls.tm = mod(ls.tm * tr[id].tm);
    rs.tm = mod(rs.tm * tr[id].tm);
    tr[id << 1 | 1] = rs;   //完事记得把局部变量取出的信息放回去
    tr[id << 1] = ls;
    return;
}

这样下放标记的函数就写好了~

放本题AC代码:

#include<bits/stdc++.h>
using namespace std;
//线段树2
#define int long long
const int maxn = (int)1e5 + 10;
struct node {
    int l, r;
    int v;
    int tm = 1, lz;
} tr[maxn << 2];
int a[maxn];
int n, m, p;

inline int mod(int x) {
    return x - x / p * p;
}

void build(int id, int l, int r) {
    tr[id].l = l;
    tr[id].r = r;
    if (l == r) {
        tr[id].v = mod(a[l]);
        return;
    }
    int mid = l + r >> 1;
    build(id << 1, l, mid);
    build(id << 1 | 1, mid + 1, r);
    tr[id].v = mod(tr[id << 1].v + tr[id << 1 | 1].v);
    return;
}

inline void putdown(int id) {
    node ls = tr[id << 1];
    node rs = tr[id << 1 | 1];
    ls.v = mod(mod(ls.v * tr[id].tm) + mod(tr[id].lz * mod(ls.r - ls.l + 1)));
    rs.v = mod(mod(rs.v * tr[id].tm) + mod(tr[id].lz * mod(rs.r - rs.l + 1)));
    ls.lz = mod(mod(ls.lz * tr[id].tm) + tr[id].lz);
    rs.lz = mod(mod(rs.lz * tr[id].tm) + tr[id].lz);
    ls.tm = mod(ls.tm * tr[id].tm);
    rs.tm = mod(rs.tm * tr[id].tm);
    tr[id << 1 | 1] = rs;
    tr[id << 1] = ls;
    return;
}

int ask(int id, int tgtl, int tgtr) {
    int l = tr[id].l;
    int r = tr[id].r;
    if (tr[id].lz || tr[id].tm != 1) {
        if (l != r) putdown(id);
        tr[id].lz = 0;
        tr[id].tm = 1;
    }
    if (l >= tgtl && r <= tgtr) return tr[id].v;
    int mid = l + r >> 1;
    if (mid >= tgtr) return mod(ask(id << 1, tgtl, tgtr));
    else if (mid < tgtl) return mod(ask(id << 1 | 1, tgtl, tgtr));
    else return mod(ask(id << 1, tgtl, tgtr) + ask(id << 1 | 1, tgtl, tgtr));
}

void chg(int id, int tgtl, int tgtr, int k, int t) {
    int l = tr[id].l;
    int r = tr[id].r;
    if (tr[id].lz || tr[id].tm != 1) {
        if (l != r) putdown(id);
        tr[id].lz = 0;
        tr[id].tm = 1;
    }
    if (l >= tgtl && r <= tgtr) {
        if (t == 1) {
            tr[id].v = mod(tr[id].v * k);
            tr[id].tm = mod(tr[id].tm * k);
        }
        else {
            tr[id].v = mod(tr[id].v + mod(k * mod(r - l + 1)));
            tr[id].lz = mod(k + tr[id].lz);
        }
        return;
    }
    int mid = l + r >> 1;
    if (mid < tgtl) chg(id << 1 | 1, tgtl, tgtr, k, t);
    else if (mid >= tgtr) chg(id << 1, tgtl, tgtr, k, t);
    else {
        chg(id << 1, tgtl, tgtr, k, t);
        chg(id << 1 | 1, tgtl, tgtr, k, t);
    }
    tr[id].v = mod(tr[id << 1].v + tr[id << 1 | 1].v);
    return;
}

signed main() {
    scanf("%lld %lld %lld", &n, &m, &p);
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);

    build(1, 1, n);

    for (int i = 1; i <= m; i++) {
        int opt, x, y, z;
        scanf("%lld %lld %lld", &opt, &x, &y);
        if (opt == 1) {
            scanf("%lld", &z);
            chg(1, x, y, mod(z), 1);
        }
        else if (opt == 2) {
            scanf("%lld", &z);
            chg(1, x, y, mod(z), 2);
        }
        else printf("%lld\n", ask(1, x, y));
    }

    return 0;
}