P3648 [APIO2014]序列分割
斯德哥尔摩
2018-08-04 19:34:30
[P3648 [APIO2014]序列分割](https://www.luogu.org/problemnew/show/P3648)
额,这题肯定$DP$。。。
首先对于一个数列,若要将其分为三段,总和分别是$a,b,c$。有两种分法:
$$a(b+c)+bc=ab+ac+bca(b+c)+bc=ab+ac+bc$$
$$(a+b)c+ab=ab+ac+bc(a+b)c+ab=ab+ac+bc$$
经过推广,可以发现切的顺序对于结果没有影响。
既然如此,我们就可以方便地定义状态$dp[p][i]$表示在前$i$之中切$p$刀的答案,则有:
$$dp[p][i]=\max_{j<i}{(dp[p-1][j]+sum[j]\times(sum[i]-sum[j]))}$$
至于当前方案的贡献为$sum[j]\times(sum[i]-sum[j])$的原因,因为顺序对答案不影响,对于现在$dp$到$i$时,可以认为在还未进行$dp$的$[i+1,n]$的范围是已经切好了的,而之前$dp$出的最优方案是以后才切的,即从后往前切。
不妨假令$j$为这$[1,i]$中的第一刀,那么这一刀对答案的贡献即为:$sum[j]\times(sum[i]-sum[j])$。
这个朴素DP的时间复杂度为$O(n^2k)$,空间复杂度为$O(nk)$。
铁定$TLE+MLE$。。。
容易想到每一次的状态转移只与上一次有关,那么用滚动数组压成两维即可。
其次,方程中有前缀和,而数列为非负整数,那么前缀和是单调递增,容易想到斜率优化。
为了方便表达,现在省去$dp$数组的第一维,大家可以默认为下方公式中的$dp$数组指上一次$dp$的状态,即$dp[p-1]$。
令现在有$j,k$两个位置,满足$j>k$,我们不妨假设选择$k$要优于选择$j$。
那么需要满足:
$$dp[j]+sum[j]\times(sum[i]-sum[j])\leq dp[k]+sum[k]\times(sum[i]-sum[k])$$
$$\Rightarrow sum[i]\leq \frac{dp[k]-sum[k]^2-dp[j]+sum[j]^2}{sum[j]-sum[k]}$$
然后由此算斜率优化。
但是有一个坑点,就是数列为非负整数,因此有可能会出现$sum[j]-sum[k]==0$,而除以一个$0$就$RE+WA$了!
而对于这种情况,可以知道不管在不在这个地方切开对答案应该没有影响,所以特判一下就好了
附代码:
```cpp
#include<iostream>
#include<algorithm>
#include<cstdio>
#define MAXN 100010
#define MAX (1LL<<60)
using namespace std;
int n,k,next[210][MAXN];
int head,tail,que[MAXN];
long long sum[MAXN],dp[2][MAXN];
inline int read(){
int date=0,w=1;char c=0;
while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();}
while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();}
return date*w;
}
inline double slope(int i,int x,int y){
if(sum[x]==sum[y])return (double)-MAX;
return (double)(dp[i&1^1][y]-dp[i&1^1][x]+sum[x]*sum[x]-sum[y]*sum[y])*1.00/(sum[x]-sum[y]);
}
void work(){
for(int i=1;i<=k;i++){
head=tail=1;
for(int j=1;j<=n;j++){
while(head<tail&&slope(i,que[head],que[head+1])<=sum[j])head++;
dp[i&1][j]=dp[i&1^1][que[head]]+sum[que[head]]*(sum[j]-sum[que[head]]);
next[i][j]=que[head];
while(head<tail&&slope(i,que[tail-1],que[tail])>=slope(i,que[tail],j))tail--;
que[++tail]=j;
}
}
printf("%lld\n",dp[k&1][n]);
for(int i=k,j=n;i>=1;i--){
j=next[i][j];
printf("%d ",j);
}
printf("\n");
}
void init(){
n=read();k=read();
sum[0]=0;
for(int i=1;i<=n;i++){
int x=read();
sum[i]=sum[i-1]+x;
}
}
int main(){
init();
work();
return 0;
}
```