李超线段树做一些 1D 问题
Whiking
·
·
题解
[HNOI2008] 玩具装箱
有 n\le 5\times 10^4 个物品,你要把他们分成若干段。
每一段 (i,j) 的长度定义为 j-i+\sum_{k=i}^jC_k。
每段的代价就是 (L-x)^2,L 为给定值,x 为这段长度。
求最小。
1D 问题,那么暴力 dp 是很好设计的。
转移就是 $dp_i=\min(dp_j+(i-(j+1)+sum_i-sum_j-L)^2)$。
$sum$ 就是 $C$ 的前缀和。
不妨把这个平方拆开?
$$
dp_i=\min(dp_j+(i-j-1+sum_i-sum_{j})^2+L^2-2L(i-j-1+sum_i-sum_{j}))
$$
然后考虑把 $j$ 和 $i$ 分开,然后把不含 $j$ 的放出去。
$$
dp_i=\min(-2(j+sum_j)(i+sum_i-1)+dp_j+(j+sum_j)^2+2L(j+sum_j))+(i+sum_{i}-1)^2-2L(i+sum_{i}-1)+L^2
$$
然后不难发现,这个 $\min$ 里的似乎是个一次函数。
$y=dp_i$。
$x=i+sum_{i}-1$。
$k=-2(j+sum_j)$。
$b=dp_j+(j+sum_j)^2+2L(j+sum_j)$。
这个时候,我们就可以李超优化了。
每次计算完 $dp_j$ 后,就可以在线段树上插入 $y=-2(j+sum_j)x+dp_j+(j+sum_j)^2+2L(j+sum_j)$ 这样的直线。
然后计算 $i$ 时,就是直接看 $x=i+sum_{i}-1$ 这条直线与已经插入的哪条直线的交点最低。
```cpp
#include <bits/stdc++.h>
#define int __int128
void Freopen() {
freopen("", "r", stdin);
freopen("", "w", stdout);
}
using namespace std;
const int N = 5e4 + 10, M = 2e5 + 10, inf = 1e15, mod = 998244353;
struct IO {
#define MAXSIZE (1 << 20)
#define isdigit(x) (x >= '0' && x <= '9')
char buf[MAXSIZE], *p1, *p2;
char pbuf[MAXSIZE], *pp;
#if DEBUG
#else
IO() : p1(buf), p2(buf), pp(pbuf) {}
~IO() { fwrite(pbuf, 1, pp - pbuf, stdout); }
#endif
char gc() {
#if DEBUG // 调试,可显示字符
return getchar();
#endif
if (p1 == p2) p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin);
return p1 == p2 ? ' ' : *p1++;
}
bool blank(char ch) {
return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';
}
template <class T>
void read(T &x) {
double tmp = 1;
bool sign = false;
x = 0;
char ch = gc();
for (; !isdigit(ch); ch = gc())
if (ch == '-') sign = 1;
for (; isdigit(ch); ch = gc()) x = x * 10 + (ch - '0');
if (ch == '.')
for (ch = gc(); isdigit(ch); ch = gc())
tmp /= 10.0, x += tmp * (ch - '0');
if (sign) x = -x;
}
void read(char *s) {
char ch = gc();
for (; blank(ch); ch = gc());
for (; !blank(ch); ch = gc()) *s++ = ch;
*s = 0;
}
void read(char &c) { for (c = gc(); blank(c); c = gc()); }
void push(const char &c) {
#if DEBUG // 调试,可显示字符
putchar(c);
#else
if (pp - pbuf == MAXSIZE) fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf;
*pp++ = c;
#endif
}
template <class T>
void write(T x) {
if (x < 0) x = -x, push('-'); // 负数输出
static T sta[35];
T top = 0;
do {
sta[top++] = x % 10, x /= 10;
} while (x);
while (top) push(sta[--top] + '0');
}
template <class T>
void write(T x, char lastChar) {
write(x), push(lastChar);
}
} io;
int n, L, rt;
int sum[N], dp[N];
int K( int i) {
return -2 * (i + sum[i]);
}
int X( int i) {
return i + sum[i] - 1;
}
int B( int i) {
return dp[i] + (i + sum[i]) * (i + sum[i]) + 2 * L * (i + sum[i]);
}
struct sgt {
#define mid ((l + r) >> 1)
int tag[N * 30], ls[N * 30], rs[N * 30];
int cnt, tot;
void init() {
cnt = tot = 0;
memset(tag, 0, sizeof tag);
}
struct line {
int k, b;
} l[N];
int cal( int id, int x) {
return l[id].k * x + l[id].b;
}
void upd( int & k, int l, int r, int u) {
if (! k) k = ++ tot;
int & v = tag[k], op = cal(u, mid) - cal(v, mid);
if (op < 0) swap(u, v);
op = cal(u, l) - cal(v, l);
if (op < 0) upd(ls[k], l, mid, u);
op = cal(u, r) - cal(v, r);
if (op < 0) upd(rs[k], mid + 1, r, u);
}
void add( int k, int b) {
l[++ cnt] = {k, b};
upd(rt, 0, inf, cnt);
}
int ask( int k, int l, int r, int x) {
if (! k) return 0;
if (l == r) return tag[k];
int res = tag[k], res1;
if (x <= mid) res1 = ask(ls[k], l, mid, x);
else res1 = ask(rs[k], mid + 1, r, x);
int op = cal(res, x) - cal(res1, x);
if (op < 0) return res;
return res1;
}
#undef mid
} t;
signed main() {
io.read(n), io.read(L);
for ( int i = 1; i <= n; i ++)
io.read(sum[i]), sum[i] += sum[i - 1];
t.add(0, 0);
for ( int i = 1; i <= n; i ++) {
dp[i] = t.cal(t.ask(rt, 0, inf, X(i)), X(i)) + X(i) * X(i) + L * L - 2 * L * X(i);
t.add(K(i), B(i));
}
io.write(dp[n]);
return 0;
}
```