题解 P4767 【[IOI2000]邮局】
吹雪吹雪吹
2018-12-19 19:45:36
一打眼就看出来是个DP优化的题!
~~然而我不会~~
40分:三方的DP不用解释了吧qwq
100分:四边形不等式(参见memset0巨佬的题解,蒟蒻我不会qwq)
101分:~~忘情水~~wqs二分
好的,我们来讲101分做法。
首先,一个很显然的结论:取的邮局越多,答案越优
其次,另一个比较显然的结论:设设置$k$个邮局的最优解为$f(k)$,则$f(k)-f(k-1)>f(k+1)-f(k)$。也就是说,该函数的大致图像是:
![](https://cdn.luogu.com.cn/upload/image_hosting/owkbc6qu.png)
接下来,观察原方程$f[i][j]$表示$1..i$放了$j$个邮局,发现第一维几乎不可能被优化掉,那么只能从第二维下手了。
如果能恰好取到$k$个邮局的话。。。~~你做梦!~~
对于此题,忘情水二分就是干这个事情的。
我们先枚举一个值$C$,表示每放一个邮局需要额外花费$C$的代价。现在,函数图像变成了这样:
![](https://cdn.luogu.com.cn/upload/image_hosting/yrw379xm.png)
图中蓝线为原函数,橙线为现函数,绿色虚线长度依次是$0C,1C,2C,3C,4C...$
可以证明,这是一个单峰函数(然而我不会证,只能感性理解)
于是我们可以二分$C$,在转移时记录当前放置邮局的次数(记录次数,不设上限),然后根据当前次数调整$L$和$R$(具体见代码)
```cpp
/*xxc 18/12/19 */
/*https://xcfubuki.cn*/
#include <cstdio>
#include <cstring>
#include <algorithm>
#define calc(x, y) (f[x] + w((x) + 1, y) + exc)
#define maxn 100005
using namespace std;
typedef long long LL;
int n, k, a[maxn], cnt[maxn], pre[maxn];
LL s[maxn], f[maxn], exc;
class jc
{
public:
int l, r, p;
} que[maxn];
inline int read()
{
char ch = getchar();
int ret = 0, f = 1;
while (ch > '9' || ch < '0')
{
if (ch == '-')
f = -f;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
ret = ret * 10 + ch - '0', ch = getchar();
return ret * f;
}
LL w(int l, int r)
{
if (l >= r)
return 0;
int mid = (r - l >> 1) + l;
LL res = a[mid] * (1ll * mid - l + 1) - (s[mid] - s[l - 1]);
res += (s[r] - s[mid - 1]) - (1ll * r - mid + 1) * a[mid];
return res;
}
int find(jc x, int s)
{
int L = x.l, R = x.r;
while (L <= R)
{
int mid = (R - L >> 1) + L;
if (calc(x.p, mid) > calc(s, mid))
R = mid - 1;
else
L = mid + 1;
}
return L;
}
int check()
{
int hed = 1, til = 0;
que[++til] = (jc){1, n, 0};
for (int i = 1; i <= n; ++i)
{
f[i] = calc(que[hed].p, i);
pre[i] = que[hed].p;
cnt[i] = cnt[que[hed].p] + 1;
int chs = -1;
while (hed <= til)
{
if (calc(i, que[til].l) < calc(que[til].p, que[til].l))
chs = que[til--].l;
else
{
int st = find(que[til], i);
if (st <= que[til].r)
chs = st, que[til].r = st - 1;
break;
}
}
if (chs != -1)
que[++til] = (jc){chs, n, i};
if (hed <= til)
{
que[hed].l++;
if (que[hed].l > que[hed].r)
hed++;
}
}
return cnt[n];
}
void output(int i)
{
if (0 == i)
return;
output(pre[i]);
printf("%d ", i);
}
int main()
{
n = read(), k = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
sort(a + 1, a + 1 + n);
for (int i = 1; i <= n; ++i)
s[i] = s[i - 1] + 1ll * a[i];
LL L = 0, R = 1e6, res = 0;
while (L <= R)
{
LL mid = (R - L >> 1) + L;
exc = mid;
if (check() <= k)
res = mid, R = mid - 1;
else
L = mid + 1;
}
exc = res;
check();
printf("%lld\n", f[n] - k * res);
return 0;
}
```