题解 P3648 【[APIO2014]序列分割】

· · 题解

题目链接

看到还没有wqs二分的题解,于是我就来写一篇qwq

首先,答案与分割顺序无关,这个可以这么证明:对于数列abca, b, c是要分割的区间),把它分开有2种方案,先分ab或先分bc,而ab+(a+b)c = a(b+c)+bc = ab+ac+bc,所以最终答案相同。

于是我们可以假设我们是从后往前分割的。

然后我们考虑dp,设f[i][p]表示当前处理了前i个数,分了p块的最大得分,转移的时候枚举上一个分割点为j。设序列a[i]的前缀和为s[i],则方程为

f[i][p] = \max_{j < i}\{f[j][p-1]+s[j] \times (s[i]-s[j])\}

初值为f[i][0] = 0

复杂度为O(n^2k)

考虑如何优化复杂度,显然状态的第一维表示位置无法优化,我们希望能让状态去掉第二维。

这里介绍一种方法叫wqs二分(也叫凸优化或带权二分),它是用来解决“恰好选k个元素使答案最优”之类的问题的,一般配合贪心或dp来使用。

考虑题目所满足的性质,可以发现,当分的块数越多时最终得分也就越大。那么我们就要给它加一个限制使它在分的块数比较多的时候答案变小。

我们可以让它每用一次分割都会付出cost的代价。显然当cost很大时,一次也不分割是最优的。

进一步发现,当cost变化时,答案取最大时分割的次数是单调的。于是我们可以二分这个cost

然后我们考虑求出在有cost条件下的最优答案。设f[i]为处理完前i个数的最优答案,转移的时候枚举上一个分割点为j。方程为:

f[i] = \max_{j < i}\{f[j] + s[j] \times (s[i]-s[j])-cost\}

看起来很像一个斜率优化,我们试着化一下方程:

f[i] = f[j] + s[j] \times (s[i]-s[j])-cost f[i] = f[j] + s[j] \times s[i] - s[j]^2-cost f[j] - s[j]^2 = f[i] - s[j] \times s[i] + cost

y(i) = f[i] - s[i]^2, x(i) = s[i], k(i) = s[i],我们发现这是斜率优化的一般形式。于是我们就可以O(1)转移了。

注意当f[i] < 0时,前面不分割更优,此时f[i]应为0,这种情况需要判掉。

然后每次转移的时候记录g[i]为当前分割的次数,最终我们要求的答案为f[i]+cost \times g[i]。我们通过二分cost使g[i]最终取到k

另外这题需要输出方案,我们在转移的时候记录一下前驱即可。

下面放代码:

#include <cstdio>
#include <cctype>
#define maxn 100005
typedef long long ll;
inline int read() {
    int d=0;char ch=getchar();while(!isdigit(ch))ch=getchar();
    while(isdigit(ch)){d=d*10+ch-48;ch=getchar();}return d;
}

int n, k;
ll s[maxn];
ll f[maxn];
int g[maxn];
ll ls, rs, mid, ans;

inline ll K(int i) {return s[i];}
inline ll X(int i) {return s[i];}
inline ll Y(int i) {return s[i]*s[i]-f[i];}
inline double slp(int i, int j) {return X(i) == X(j) ? 1e18 : ((double)Y(i)-Y(j))/((double)X(i)-X(j));}

int que[maxn], he, ta;

int check(ll cst) {
    que[he = ta = 0] = 1;
    f[1] = g[1] = 0;
    for(int i = 2; i <= n; ++i) {
        while(he < ta && slp(que[he], que[he+1]) < K(i)) ++he;
        f[i] = f[que[he]] + s[que[he]] * (s[i] - s[que[he]]) - cst, g[i] = g[que[he]] + 1;
        if(f[i] < 0) f[i] = g[i] = 0;
        while(he < ta && slp(que[ta], que[ta-1]) > slp(que[ta-1], i)) --ta;
        que[++ta] = i;
    }
    return g[n];
}

int main() {
    n = read(), k = read();
    for(int i = 1; i <= n; ++i)
        s[i] = read() + s[i-1];
    ls = 0, rs = 1e18;
    while(ls <= rs) {
        mid = (ls+rs)>>1;
        if(check(mid) >= k) ls = mid+1, ans = mid;
        else rs = mid-1;
    }
    que[he = ta = 0] = 1;
    f[1] = g[1] = 0;
    for(int i = 2; i <= n; ++i) {
        while(he < ta && slp(que[he], que[he+1]) < K(i)) ++he;
        f[i] = f[que[he]] + s[que[he]] * (s[i] - s[que[he]]) - ans, g[i] = que[he];
        if(f[i] < 0) f[i] = g[i] = 0;
        while(he < ta && slp(que[ta], que[ta-1]) > slp(que[ta-1], i)) --ta;
        que[++ta] = i;
    }
    printf("%lld\n", f[n] + ans*k);
    for(int now = n; g[now]; now = g[now])
        printf("%d ", g[now]);
    return 0;
}