题解 P4767 【[IOI2000]邮局】

吹雪吹雪吹

2018-12-19 19:45:36

Solution

一打眼就看出来是个DP优化的题! ~~然而我不会~~ 40分:三方的DP不用解释了吧qwq 100分:四边形不等式(参见memset0巨佬的题解,蒟蒻我不会qwq) 101分:~~忘情水~~wqs二分 好的,我们来讲101分做法。 首先,一个很显然的结论:取的邮局越多,答案越优 其次,另一个比较显然的结论:设设置$k$个邮局的最优解为$f(k)$,则$f(k)-f(k-1)>f(k+1)-f(k)$。也就是说,该函数的大致图像是: ![](https://cdn.luogu.com.cn/upload/image_hosting/owkbc6qu.png) 接下来,观察原方程$f[i][j]$表示$1..i$放了$j$个邮局,发现第一维几乎不可能被优化掉,那么只能从第二维下手了。 如果能恰好取到$k$个邮局的话。。。~~你做梦!~~ 对于此题,忘情水二分就是干这个事情的。 我们先枚举一个值$C$,表示每放一个邮局需要额外花费$C$的代价。现在,函数图像变成了这样: ![](https://cdn.luogu.com.cn/upload/image_hosting/yrw379xm.png) 图中蓝线为原函数,橙线为现函数,绿色虚线长度依次是$0C,1C,2C,3C,4C...$ 可以证明,这是一个单峰函数(然而我不会证,只能感性理解) 于是我们可以二分$C$,在转移时记录当前放置邮局的次数(记录次数,不设上限),然后根据当前次数调整$L$和$R$(具体见代码) ```cpp /*xxc 18/12/19 */ /*https://xcfubuki.cn*/ #include <cstdio> #include <cstring> #include <algorithm> #define calc(x, y) (f[x] + w((x) + 1, y) + exc) #define maxn 100005 using namespace std; typedef long long LL; int n, k, a[maxn], cnt[maxn], pre[maxn]; LL s[maxn], f[maxn], exc; class jc { public: int l, r, p; } que[maxn]; inline int read() { char ch = getchar(); int ret = 0, f = 1; while (ch > '9' || ch < '0') { if (ch == '-') f = -f; ch = getchar(); } while (ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0', ch = getchar(); return ret * f; } LL w(int l, int r) { if (l >= r) return 0; int mid = (r - l >> 1) + l; LL res = a[mid] * (1ll * mid - l + 1) - (s[mid] - s[l - 1]); res += (s[r] - s[mid - 1]) - (1ll * r - mid + 1) * a[mid]; return res; } int find(jc x, int s) { int L = x.l, R = x.r; while (L <= R) { int mid = (R - L >> 1) + L; if (calc(x.p, mid) > calc(s, mid)) R = mid - 1; else L = mid + 1; } return L; } int check() { int hed = 1, til = 0; que[++til] = (jc){1, n, 0}; for (int i = 1; i <= n; ++i) { f[i] = calc(que[hed].p, i); pre[i] = que[hed].p; cnt[i] = cnt[que[hed].p] + 1; int chs = -1; while (hed <= til) { if (calc(i, que[til].l) < calc(que[til].p, que[til].l)) chs = que[til--].l; else { int st = find(que[til], i); if (st <= que[til].r) chs = st, que[til].r = st - 1; break; } } if (chs != -1) que[++til] = (jc){chs, n, i}; if (hed <= til) { que[hed].l++; if (que[hed].l > que[hed].r) hed++; } } return cnt[n]; } void output(int i) { if (0 == i) return; output(pre[i]); printf("%d ", i); } int main() { n = read(), k = read(); for (int i = 1; i <= n; ++i) a[i] = read(); sort(a + 1, a + 1 + n); for (int i = 1; i <= n; ++i) s[i] = s[i - 1] + 1ll * a[i]; LL L = 0, R = 1e6, res = 0; while (L <= R) { LL mid = (R - L >> 1) + L; exc = mid; if (check() <= k) res = mid, R = mid - 1; else L = mid + 1; } exc = res; check(); printf("%lld\n", f[n] - k * res); return 0; } ```