【线段树】区间修改(区间加、区间乘)标记下放顺序及具体操作
具体题目请参考洛谷P3373 线段树 2
这道题的数据规模挺大的,不开long long只能拿到30分,写的时候记得注意一下,然后我们可以开始讲线段树的部分了。
线段树同时对区间有加操作和乘操作时,它们的lazy标记下放顺序是经常被讨论的问题,当然,严格来说二者孰先孰后都是不会对答案造成影响的,无非是运算表达式的写法不同,但是在计算机中,精度将是一个不得不被考虑的问题。所以就有了这个老生常谈的事情,下面我们来细致讨论一下它们哪个先做会更优。
我们知道,lazy的原理就是将线段树上的区间操作延后进行,在遍历到有标记的区间时,这一步被标记的操作才会真正的向下进行,否则就会被存在标记中无限地被延后。
具体地,当我们需要为线段树上的节点打lazy标记时,我们需要为线段树的每一个节点都添加两个标记,分别储存它们的区间加和区间乘操作信息。然后在修改时,如果遍历到一个区间l >= tgtl && r <= tgtr则直接修改这一区间的值(tgtl和tgtr是目标区间的左右端点),并在其上打上标记,紧接着返回,返回时修改遍历路径上的信息。
更加具体地,做加法修改时要将操作数乘要修改的区间的长度,因为线段树的每个节点存的是区间内所有元素的和,即
写成代码块的形式就是这样:
if (l >= tgtl && r <= tgtr) {
if (t == 1) { //t==1为区间乘,t==2为区间加
tr[id].v = mod(tr[id].v * k);//int mod(int x)为手写取模函数
tr[id].tm = mod(tr[id].tm * k);//打上乘法标记
}
else {
tr[id].v = mod(tr[id].v + mod(k * mod(r - l + 1)));//修改权值时注意乘区间长度
tr[id].lz = mod(k + tr[id].lz);//打上加法标记
}
return;
}
这是具体的修改过程,下面我们来看看标记下放。
先说原理,我们已经在修改函数的边界处对两种操作打下了相应的lazy标记然后返回,所以当我们遍历到一个节点有lazy标记时,意味着该节点的子树仍有未被执行的操作被延后了,所以我们在遍历到某个节点时,就将这个节点的lazy标记下放到它的儿子节点上,以此来为后面的操作做准备。下放时应该将标记同时下放到它的左右儿子上,由于线段树是二分产生的,所以线段树上所有点的度数非零即二,不需要为任何节点考虑它是存在一个还是两个子节点。
更加具体地,我们来探讨一下是先下放加标记还是先下放乘标记。
-
如果我们先下放了加法标记,则有:
设当前节点的lazy加标记为lz,乘标记为tm,儿子节点的左右端点为l和r。 儿子节点的权 = ((r - l + 1) * lz + ∑a[i]) * tm 这么看还没有什么问题,继续。 我们的又一轮操作后,又一次下放到了这个点,这个节点的权需要进行更新了 此时节点的权 = ((((r - l + 1) * lz + ∑a[i]) * tm * (r - l + 1)) + lz') * tm' = ((r - l + 1) * (lz + lz'/ tm) + ∑a[i]) * tm * tm'此时我们发现,算式中出现了除法,这是我们要尽量避免的,如果式子中有除法,运算时就有可能出现精度问题;更有甚者,如果涉及到取模,还有可能需要使用乘法逆元,非常麻烦。
-
所以我们尝试先下放乘法标记,则:
设表示方法同上。 则儿子的权 = (tm * ∑a[i]) + lz * (r - l + 1) 新的一轮操作呢?又一次下放后, 儿子的权 = ((tm * ∑a[i]) + lz * (r - l + 1) * tm' + lz' * (r - l + 1) 又是乘法分配律,可以得到: 儿子的权 = (tm * tm' * ∑a[i]) + (lz + lz') * (r - l + 1)非常好,这下没有除法出现了,精度问题得到了较好地解决。
综上所述,我们优先下放乘法标记。
那么下放标记的函数呢,写成代码块的形式就是这样:
inline void putdown(int id) {
node ls = tr[id << 1]; //为了方便写,先用局部变量抽出来
node rs = tr[id << 1 | 1];
ls.v = mod(mod(ls.v * tr[id].tm) + mod(tr[id].lz * mod(ls.r - ls.l + 1)));//先下放乘法标记,然后再加法,中间需要的地方取模
rs.v = mod(mod(rs.v * tr[id].tm) + mod(tr[id].lz * mod(rs.r - rs.l + 1)));
ls.lz = mod(mod(ls.lz * tr[id].tm) + tr[id].lz);//标记单独下放
rs.lz = mod(mod(rs.lz * tr[id].tm) + tr[id].lz);
ls.tm = mod(ls.tm * tr[id].tm);
rs.tm = mod(rs.tm * tr[id].tm);
tr[id << 1 | 1] = rs; //完事记得把局部变量取出的信息放回去
tr[id << 1] = ls;
return;
}
这样下放标记的函数就写好了~
放本题AC代码:
#include<bits/stdc++.h>
using namespace std;
//线段树2
#define int long long
const int maxn = (int)1e5 + 10;
struct node {
int l, r;
int v;
int tm = 1, lz;
} tr[maxn << 2];
int a[maxn];
int n, m, p;
inline int mod(int x) {
return x - x / p * p;
}
void build(int id, int l, int r) {
tr[id].l = l;
tr[id].r = r;
if (l == r) {
tr[id].v = mod(a[l]);
return;
}
int mid = l + r >> 1;
build(id << 1, l, mid);
build(id << 1 | 1, mid + 1, r);
tr[id].v = mod(tr[id << 1].v + tr[id << 1 | 1].v);
return;
}
inline void putdown(int id) {
node ls = tr[id << 1];
node rs = tr[id << 1 | 1];
ls.v = mod(mod(ls.v * tr[id].tm) + mod(tr[id].lz * mod(ls.r - ls.l + 1)));
rs.v = mod(mod(rs.v * tr[id].tm) + mod(tr[id].lz * mod(rs.r - rs.l + 1)));
ls.lz = mod(mod(ls.lz * tr[id].tm) + tr[id].lz);
rs.lz = mod(mod(rs.lz * tr[id].tm) + tr[id].lz);
ls.tm = mod(ls.tm * tr[id].tm);
rs.tm = mod(rs.tm * tr[id].tm);
tr[id << 1 | 1] = rs;
tr[id << 1] = ls;
return;
}
int ask(int id, int tgtl, int tgtr) {
int l = tr[id].l;
int r = tr[id].r;
if (tr[id].lz || tr[id].tm != 1) {
if (l != r) putdown(id);
tr[id].lz = 0;
tr[id].tm = 1;
}
if (l >= tgtl && r <= tgtr) return tr[id].v;
int mid = l + r >> 1;
if (mid >= tgtr) return mod(ask(id << 1, tgtl, tgtr));
else if (mid < tgtl) return mod(ask(id << 1 | 1, tgtl, tgtr));
else return mod(ask(id << 1, tgtl, tgtr) + ask(id << 1 | 1, tgtl, tgtr));
}
void chg(int id, int tgtl, int tgtr, int k, int t) {
int l = tr[id].l;
int r = tr[id].r;
if (tr[id].lz || tr[id].tm != 1) {
if (l != r) putdown(id);
tr[id].lz = 0;
tr[id].tm = 1;
}
if (l >= tgtl && r <= tgtr) {
if (t == 1) {
tr[id].v = mod(tr[id].v * k);
tr[id].tm = mod(tr[id].tm * k);
}
else {
tr[id].v = mod(tr[id].v + mod(k * mod(r - l + 1)));
tr[id].lz = mod(k + tr[id].lz);
}
return;
}
int mid = l + r >> 1;
if (mid < tgtl) chg(id << 1 | 1, tgtl, tgtr, k, t);
else if (mid >= tgtr) chg(id << 1, tgtl, tgtr, k, t);
else {
chg(id << 1, tgtl, tgtr, k, t);
chg(id << 1 | 1, tgtl, tgtr, k, t);
}
tr[id].v = mod(tr[id << 1].v + tr[id << 1 | 1].v);
return;
}
signed main() {
scanf("%lld %lld %lld", &n, &m, &p);
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int opt, x, y, z;
scanf("%lld %lld %lld", &opt, &x, &y);
if (opt == 1) {
scanf("%lld", &z);
chg(1, x, y, mod(z), 1);
}
else if (opt == 2) {
scanf("%lld", &z);
chg(1, x, y, mod(z), 2);
}
else printf("%lld\n", ask(1, x, y));
}
return 0;
}