线段树维护区间历史版本和

· · 个人记录

Problem

给定长度 n 的正整数序列 a,你需要支持下面的操作 m 次:

以原序列为时刻 0,接下来每进行一次操作则时刻 +1

n,m \leq 3 \times 10^5,1 \leq a_i,k \leq 10^9

Solution

考虑用线段树标记永久化后维护。

我们让线段树中存这么几个变量:

其中 $add$ 和 $del$ 都是节点的标记,也就是需要永久化的变量。 然后 $Sum$ 表示当前节点的区间和,$Sumh$ 表示当前区间历史和,$tim$ 表示上一次区间和改变的时间。 首先,假设我们假设对于一个节点 $[l_i,r_i]$ 在 $t_1,t_2,t_3...t_k$ 时刻都被进行了加操作。 那么我们查询这个区间在时刻 $T$ 的值的时候,这些标记的贡献为: $(r_i - l_i + 1) \times \displaystyle \sum_{i = 1}^{k}V_i \times (T - t_i)$。 其中 $V_i$ 表示的就是第 $i$ 次加操作的值域。 要算出这个式子,我们只需要知道 $\displaystyle \sum_{i = 1}^{k} V_i, \sum_{i = 1}^{k} V_i \times t_i$ 即可。 于是我们的 $add$ 以及 $del$ 就分别表示上面的两个量。 标记永久化的情况下,我们每次锁定位置打标记后可以计算这些打标记的节点的子树的答案,但是如何直接修改一次修改中被涉及到的节点呢。 我们可以知道,它们的 $sum$ 有可能会变化,那么我们就需要将 $tim$ 修改为当前修改操作的时间戳,并且**在这之前**要将 $sumh += sum \times (T - tim)$。 这样就可以维护了。 最后如果是询问区间和就直接求,如果需要求历史和的话,求出历史和之后再加上当前的区间和就行了,因为根据我们的计算方式肯定是会漏掉当前版本的区间和的(因为是 $\times (T - tim)$ 而并没有 $+ 1$) 。 #### Code ```cpp #include <bits/stdc++.h> using namespace std; #define int long long #define Rep(i, l, r) for(int i = l; i <= r ; i ++) #define Lep(i, r, l) for(int i = r; i >= l ; i --) inline void read() {} template <typename T, typename... Args> inline void read(T& x, Args&... args) { int flag = 1; x = 0; char ch = getchar(); for( ; ch > '9' || ch < '0' ; ch = getchar()) if(ch == '-') flag = -1; for( ; ch >= '0' && ch <= '9' ; ch = getchar()) x = x * 10 + ch - '0'; x *= flag, read(args...); } const int MAXN = 3e5 + 50, Mod = 998244353; int n, m, A[MAXN], S[MAXN]; struct SegmentTree { int l, r, addh, del; int Sum, tim, Sumh; } T[MAXN << 2]; void build(int x, int l, int r) { T[x].l = l, T[x].r = r, T[x].del = T[x].addh = 0; T[x].Sumh = T[x].Sum = T[x].tim = 0; int mid = (l + r) >> 1; if(l == r) { T[x].Sum = A[l]; return ; } build(x << 1, l, mid), build(x << 1 | 1, mid + 1, r); T[x].Sum = (T[x << 1].Sum + T[x << 1 | 1].Sum) % Mod; return ; } void change(int x, int l, int r, int k, int tim) { if(T[x].l >= l && T[x].r <= r) { T[x].addh += k % Mod, T[x].del += tim * k % Mod; T[x].Sumh += T[x].Sum * (tim - T[x].tim) % Mod; T[x].Sum += (T[x].r - T[x].l + 1) * k % Mod, T[x].tim = tim; T[x].Sum %= Mod, T[x].Sumh %= Mod, T[x].addh %= Mod, T[x].del %= Mod; return ; } int mid = (T[x].l + T[x].r) >> 1, L = max(T[x].l, l), R = min(T[x].r, r); if(l <= mid) change(x << 1, l, r, k, tim); if(r > mid) change(x << 1 | 1, l, r, k, tim); T[x].Sumh += T[x].Sum * (tim - T[x].tim) % Mod, T[x].Sumh %= Mod; T[x].Sum += (R - L + 1) * k % Mod, T[x].tim = tim, T[x].Sum %= Mod; return ; } int GetSum(int x, int l, int r, int addh) { int mid = (T[x].l + T[x].r) >> 1, S = 0; if(T[x].l >= l && T[x].r <= r) { return (T[x].Sum + addh * (T[x].r - T[x].l + 1) % Mod) % Mod; } addh += T[x].addh, addh %= Mod; if(l <= mid) S += GetSum(x << 1, l, r, addh), S %= Mod; if(r > mid) S += GetSum(x << 1 | 1, l, r, addh), S %= Mod; return S; } int GetSumh(int x, int l, int r, int addh, int del, int tim) { int mid = (T[x].l + T[x].r) >> 1, S = 0; if(T[x].l >= l && T[x].r <= r) { int Len = (T[x].r - T[x].l + 1); return ( (T[x].Sumh + addh * Len % Mod * tim % Mod - del * Len % Mod + Mod) % Mod + T[x].Sum * (tim - T[x].tim) % Mod ) % Mod; } addh += T[x].addh, del += T[x].del, addh %= Mod, del %= Mod; if(l <= mid) S += GetSumh(x << 1, l, r, addh, del, tim), S %= Mod; if(r > mid) S += GetSumh(x << 1 | 1, l, r, addh, del, tim), S %= Mod; return S; } signed main() { read(n, m); Rep(i, 1, n) read(A[i]); build(1, 1, n); Rep(i, 1, m) { int op, l, r, k; read(op, l, r); if(op == 1) read(k), change(1, l, r, k, i); if(op == 2) printf("%lld\n", GetSum(1, l, r, 0)); if(op == 3) printf("%lld\n", (GetSumh(1, l, r, 0, 0, i) + GetSum(1, l, r, 0)) % Mod); } return 0; } ```