题解:AT_abc455_f [ABC455F] Merge Slimes 2
yuanweic
·
·
题解
luogu 传送门:题目
Atcoder 传送门:题目
第一次赛时做出 F!
(本文用 DeepSeek 润色)
题意简述:
维护一个长度为 N 的数组 A,初始全为 0。
有 Q 次操作:
区间加:将 [l,r] 内的每个数加上 a。
查询:取出 [l,r] 内的所有数作为史莱姆的重量,求合并所有史莱姆的最小总代价。
其中,合并两个重量为 x 和 y 的史莱姆花费 x \times y,合并 M 个数需要 M - 1 次操作。
读题即知线段树。
关键结论:
对于一组数 B_1, B_2, \dots, B_M,最小合并代价为:
ans = \frac{\left(\sum_{i = 1}^{M} B_{i} \right)^2 - \sum_{i = 1}^{M} B_{i}^{2}}{2}
证明思路:
任意合并顺序的总代价等于 \sum_{1 \le i < j \le M} B_{i} B_{j}。
这是因为每对史莱姆 (i,j) 恰好会在某次合并中相遇一次,贡献 B_{i} \times B_{j}。
由恒等式 (\sum B_{i})^{2} = \sum B_{i}^{2} + 2\sum_{i < j} B_{i}B_{j} 得证。
现在可以将问题转化,每个查询只需要知道:
答案即为 (sum^{2} - sum_{2}) /2 \bmod 998244353。
但因为 C++ 中的 / 是整数除法(向下取整),而不是模意义下的除法,它无法处理模运算中的分数和取模后的值。所以不能直接除,要乘上 2 的逆元,即乘 (mod + 1) / 2 或 499122177。
实现:
使用线段树 + 懒标记,每个节点维护:
$sum_{2}$:区间平方和。
$tag$:懒标记。
区间加更新公式:
区间 $[l,r]$ 内每个数加上 $v$:
区间和:$S^{'} = S + \text{len} \times v$。
区间平方和:$S_2^{'} = \sum (x + v)^{2} = \sum (x^{2} + 2vx + v^{2}) = S_2 + 2vS + len \times v^{2}$。
时间复杂度:
建树:$O(N)$。
每次操作:$O(\log N)$。
总复杂度:$O(N + Q\log N)$。
::::success[AC Code]
```
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 998244353;
const int inv2 = (mod + 1) / 2;
const int N = 1e5 + 5;
int n,q,a[N];
struct node{
int sum,sum2,tag;
}tr[N << 2];
void build(int i,int l,int r){
tr[i].sum = 0,tr[i].sum2 = 0,tr[i].tag = 0;
if(l == r){
return;
}
int mid = (l + r) >> 1;
build(i << 1,l,mid);
build(i << 1 | 1,mid + 1,r);
}
void pushdown(int i,int l,int r){
if(tr[i].tag != 0){
int mid = (l + r) >> 1;
int l1 = mid - l + 1;
int r1 = r - mid;
int v = tr[i].tag;
int suml = tr[i << 1].sum;
tr[i << 1].sum = (tr[i << 1].sum + l1 * v) % mod;
tr[i << 1].sum2 = (tr[i << 1].sum2 + 2 * v % mod * suml % mod + l1 * v % mod * v % mod) % mod;
tr[i << 1].tag = (tr[i << 1].tag + v) % mod;
int sumr = tr[i << 1 | 1].sum;
tr[i << 1 | 1].sum = (tr[i << 1 | 1].sum + r1 * v) % mod;
tr[i << 1 | 1].sum2 = (tr[i << 1 | 1].sum2 + 2 * v % mod * sumr % mod + r1 * v % mod * v % mod) % mod;
tr[i << 1 | 1].tag = (tr[i << 1 | 1].tag + v) % mod;
}
tr[i].tag = 0;
}
void modify(int i,int l,int r,int ql,int qr,int v){
if(ql <= l && r <= qr){
int len = r - l + 1;
int sum0 = tr[i].sum;
tr[i].sum = (tr[i].sum + len * v) % mod;
tr[i].sum2 = (tr[i].sum2 + 2 * v % mod * sum0 % mod + len * v % mod * v % mod) % mod;
tr[i].tag = (tr[i].tag + v) % mod;
return;
}
pushdown(i,l,r);
int mid = (l + r) >> 1;
if(ql <= mid){
modify(i << 1,l,mid,ql,qr,v);
}
if(qr > mid){
modify(i << 1 | 1,mid + 1,r,ql,qr,v);
}
tr[i].sum = (tr[i << 1].sum + tr[i << 1 | 1].sum) % mod;
tr[i].sum2 = (tr[i << 1].sum2 + tr[i << 1 | 1].sum2) % mod;
}
void query(int i,int l,int r,int ql,int qr,int &sum,int &sum2){
if(ql <= l && r <= qr){
sum = (sum + tr[i].sum) % mod;
sum2 = (sum2 + tr[i].sum2) % mod;
return;
}
pushdown(i,l,r);
int mid = (l + r) >> 1;
if(ql <= mid){
query(i << 1,l,mid,ql,qr,sum,sum2);
}
if(qr > mid){
query(i << 1 | 1,mid + 1,r,ql,qr,sum,sum2);
}
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin >> n >> q;
build(1,1,n);
while(q--){
int l,r,a;
cin >> l >> r >> a;
a %= mod;
modify(1,1,n,l,r,a);
int sum = 0,sum2 = 0;
query(1,1,n,l,r,sum,sum2);
cout << (sum * sum % mod - sum2 + mod) % mod * inv2 % mod << '\n';
}
return 0;
}
```
::::
[AC 记录](https://atcoder.jp/contests/abc455/submissions/75256981)