李超线段树做一些 1D 问题

· · 题解

[HNOI2008] 玩具装箱

n\le 5\times 10^4 个物品,你要把他们分成若干段。

每一段 (i,j) 的长度定义为 j-i+\sum_{k=i}^jC_k

每段的代价就是 (L-x)^2L 为给定值,x 为这段长度。

求最小。

1D 问题,那么暴力 dp 是很好设计的。

转移就是 $dp_i=\min(dp_j+(i-(j+1)+sum_i-sum_j-L)^2)$。 $sum$ 就是 $C$ 的前缀和。 不妨把这个平方拆开? $$ dp_i=\min(dp_j+(i-j-1+sum_i-sum_{j})^2+L^2-2L(i-j-1+sum_i-sum_{j})) $$ 然后考虑把 $j$ 和 $i$ 分开,然后把不含 $j$ 的放出去。 $$ dp_i=\min(-2(j+sum_j)(i+sum_i-1)+dp_j+(j+sum_j)^2+2L(j+sum_j))+(i+sum_{i}-1)^2-2L(i+sum_{i}-1)+L^2 $$ 然后不难发现,这个 $\min$ 里的似乎是个一次函数。 $y=dp_i$。 $x=i+sum_{i}-1$。 $k=-2(j+sum_j)$。 $b=dp_j+(j+sum_j)^2+2L(j+sum_j)$。 这个时候,我们就可以李超优化了。 每次计算完 $dp_j$ 后,就可以在线段树上插入 $y=-2(j+sum_j)x+dp_j+(j+sum_j)^2+2L(j+sum_j)$ 这样的直线。 然后计算 $i$ 时,就是直接看 $x=i+sum_{i}-1$ 这条直线与已经插入的哪条直线的交点最低。 ```cpp #include <bits/stdc++.h> #define int __int128 void Freopen() { freopen("", "r", stdin); freopen("", "w", stdout); } using namespace std; const int N = 5e4 + 10, M = 2e5 + 10, inf = 1e15, mod = 998244353; struct IO { #define MAXSIZE (1 << 20) #define isdigit(x) (x >= '0' && x <= '9') char buf[MAXSIZE], *p1, *p2; char pbuf[MAXSIZE], *pp; #if DEBUG #else IO() : p1(buf), p2(buf), pp(pbuf) {} ~IO() { fwrite(pbuf, 1, pp - pbuf, stdout); } #endif char gc() { #if DEBUG // 调试,可显示字符 return getchar(); #endif if (p1 == p2) p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin); return p1 == p2 ? ' ' : *p1++; } bool blank(char ch) { return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t'; } template <class T> void read(T &x) { double tmp = 1; bool sign = false; x = 0; char ch = gc(); for (; !isdigit(ch); ch = gc()) if (ch == '-') sign = 1; for (; isdigit(ch); ch = gc()) x = x * 10 + (ch - '0'); if (ch == '.') for (ch = gc(); isdigit(ch); ch = gc()) tmp /= 10.0, x += tmp * (ch - '0'); if (sign) x = -x; } void read(char *s) { char ch = gc(); for (; blank(ch); ch = gc()); for (; !blank(ch); ch = gc()) *s++ = ch; *s = 0; } void read(char &c) { for (c = gc(); blank(c); c = gc()); } void push(const char &c) { #if DEBUG // 调试,可显示字符 putchar(c); #else if (pp - pbuf == MAXSIZE) fwrite(pbuf, 1, MAXSIZE, stdout), pp = pbuf; *pp++ = c; #endif } template <class T> void write(T x) { if (x < 0) x = -x, push('-'); // 负数输出 static T sta[35]; T top = 0; do { sta[top++] = x % 10, x /= 10; } while (x); while (top) push(sta[--top] + '0'); } template <class T> void write(T x, char lastChar) { write(x), push(lastChar); } } io; int n, L, rt; int sum[N], dp[N]; int K( int i) { return -2 * (i + sum[i]); } int X( int i) { return i + sum[i] - 1; } int B( int i) { return dp[i] + (i + sum[i]) * (i + sum[i]) + 2 * L * (i + sum[i]); } struct sgt { #define mid ((l + r) >> 1) int tag[N * 30], ls[N * 30], rs[N * 30]; int cnt, tot; void init() { cnt = tot = 0; memset(tag, 0, sizeof tag); } struct line { int k, b; } l[N]; int cal( int id, int x) { return l[id].k * x + l[id].b; } void upd( int & k, int l, int r, int u) { if (! k) k = ++ tot; int & v = tag[k], op = cal(u, mid) - cal(v, mid); if (op < 0) swap(u, v); op = cal(u, l) - cal(v, l); if (op < 0) upd(ls[k], l, mid, u); op = cal(u, r) - cal(v, r); if (op < 0) upd(rs[k], mid + 1, r, u); } void add( int k, int b) { l[++ cnt] = {k, b}; upd(rt, 0, inf, cnt); } int ask( int k, int l, int r, int x) { if (! k) return 0; if (l == r) return tag[k]; int res = tag[k], res1; if (x <= mid) res1 = ask(ls[k], l, mid, x); else res1 = ask(rs[k], mid + 1, r, x); int op = cal(res, x) - cal(res1, x); if (op < 0) return res; return res1; } #undef mid } t; signed main() { io.read(n), io.read(L); for ( int i = 1; i <= n; i ++) io.read(sum[i]), sum[i] += sum[i - 1]; t.add(0, 0); for ( int i = 1; i <= n; i ++) { dp[i] = t.cal(t.ask(rt, 0, inf, X(i)), X(i)) + X(i) * X(i) + L * L - 2 * L * X(i); t.add(K(i), B(i)); } io.write(dp[n]); return 0; } ```