[线段树] [分段函数] P5609 [Ynoi2013] 对数据结构的爱

· · 题解

尝试更自然地叙述思路。唯一一处跳跃性步骤也是符合人类直觉的。

题意:给你长为 n 的操作序列 a 及定值 pm 次询问,每次给出 [l,r,x],求出 x 依次进行操作 [l,r] 后的值。一次操作会使得 x\gets x+a_i,然后若 x\ge px\gets x-pn\le 10^6,m\le 2\times 10^5

不妨从函数的视角出发,容易发现这是一个由若干段 k=1 线段构成的分段函数。显然初值越大就会被减去更多次 p,于是这个分段函数可以用一个序列 x_i 描述,初值在 [x_i,x_{i+1}) 都满足恰被减去 ip

由于 |x| 是操作个数级别的,很自然想到上线段树,对每个区间维护其分段函数。假如能预处理出 x,那么每次查询只需拆分区间后,依次二分得到会被减去几次 p 即可 O(\log^2 n) 算出答案。

重点是怎么实现 push_up 函数。考虑朴素合并,对于左边区间 i,若其终值区间覆盖了右边区间 j 的左端点,则可以向 x_{i+j} 贡献,记这个贡献是 f(i,j),显然 f(i,j)\in [x_i,x_{i+1}),于是 f(i,j)<f(i+1,j-1),于是 j 最大即可取到最小值。

(为了方便,不妨放宽限制化区间为前缀,只不过贡献时与 a_i 取个 \max 即可。)

合法 i,j 判定条件即为:x_{i+1}-1+s_{lson}-ip\ge x_j。于是 i 合法的 j 是段前缀。同时合理猜想 i 的合法 j 前缀边界随着 i 增大而增大,即 j 合法的 i 是段后缀。

那么直接双指针即可,复杂度为 O(n\log n)

证明:等价于 x_i-ip 递增,接下来证 x_{i+1}-x_i\ge p。你考虑代初值为 x_i 时,恰好做完 i 次减法后值必然为 0 且之后值不能超过 0,不然 x_i 可取更小,因此至少再加上 p 才能多减一次。

总结:本题运用线段树维护分段函数,重点为合并儿子信息。需要一定的直觉。

代码

#include<bits/stdc++.h>
using namespace std;

#define int long long
#define MIN(a, b) a = min(a, b)
const int N = 1e6 + 5, inf = 1e17, M = 3e7 + 5;
int n, m, p, a[N], now, l, r, lstans;

namespace SGT{
    #define lt (u << 1)
    #define rt (u << 1 | 1)
    #define mid (l + r >> 1)
    int s[N << 2], x[M], beg[N << 2], siz[N << 2], tot;

    inline void build(int u, int l, int r){
        beg[u] = tot, siz[u] = r - l + 3; tot += siz[u];
        for(int i = 0; i <= r - l + 2; ++i) x[beg[u] + i] = inf;
        if(l == r) {x[beg[u]] = -inf, x[beg[u] + 1] = p - a[l], s[u] = a[l]; return ;}
        build(lt, l, mid), build(rt, mid + 1, r);
        s[u] = s[lt] + s[rt];
        int lsiz = mid - l + 1, rsiz = r - mid;
        for(int i = 0, j = 0; i <= lsiz; ++i){
            while(1){
                MIN(x[beg[u] + i + j], max(x[beg[lt] + i], x[beg[rt] + j] - s[lt] + i * p));
                ++j;
                if(j > rsiz) {j = rsiz; break;}
                if(x[beg[lt] + i + 1] - 1 + s[lt] - i * p < x[beg[rt] + j]) {--j; break;}    
            }
        }
    }
    inline void fid(int u, int l, int r, int ll, int rr){
        if(ll <= l && r <= rr){
            int val = upper_bound(x + beg[u], x + beg[u] + siz[u], now) - (x + beg[u]) - 1;
            now = now + s[u] - val * p;
            return ;
        }
        if(ll <= mid) fid(lt, l, mid, ll, rr);
        if(rr > mid) fid(rt, mid + 1, r, ll, rr);
    }
}using namespace SGT;
signed main(){
    cin >> n >> m >> p;
    for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
    build(1, 1, n);
    while(m--){
        scanf("%lld%lld%lld", &l, &r, &now);
        l ^= lstans, r ^= lstans, now ^= lstans;
        assert(l <= r);
        fid(1, 1, n, l, r);
        printf("%lld\n", now);
        lstans = (now % n + n) % n;
    }    
    return 0;
}