题解:P5665 [CSP-S2019] 划分

· · 题解

思路分析

首先,我们有一个结论:划分的段数越多越好。

如何证明呢?可以感性理解:若将 a_{1\sim i} 分成 k 段,则所有的方案中,最后一段小一点的会更优。而分的段数越多,最后一段就可以越小,所以我们会让分的段数尽可能的多。

如何求最小段数呢?一个直观的 dp 思路是,设 f_{i,j}1\sim i 划分,最后一段右端点为 i,倒数一段右端点为 j,直接转移即可,时间复杂度 O(n^3),可以拿到 36 的分数。

状态难以简化,不妨从贪心角度考虑,我们从前往后取。如果当前段的和大于我们上一个划分的段,则令答案加一,开始划分新的一段,反之则将一个新的元素放入当前段。这个贪心实际上是错的,可以 Hack 掉。

例如给定的数列是 \{5,5,1,2,8,9\},我们的贪心策略划分的结果是 \{5\},\{5\},\{1,2,8\},但实际上的最优策略是 \{5\},\{5,1,2\},\{8\},\{9\},我们思考贪心算法错误的原因。

倘若要将 \{5,5,1,2,8\} 这部分划分成三段,则上面的两种算法分别将第三段划为了 \{1,2,8\}\{8\},这两种方法都是合法的,但是第二种划分这一段的和更小,所以对于后面的转移更优。所以,我们划分时,不仅要让段数多的情况下合法,还要让最后一段的和尽可能小。

我们依据这个贪心思路重新设计 dp。设 f_i 表示将 1\sim i 合法划分的最大段数,g_i 表示将 1\sim i 划分成 f_i 的基础上,最小化最后一段的总和,同时记录序列 a 的前缀和 s,则可以得到 f 的转移式为 f_i=\max\{f_j+1\} 满足 j\in [0,i-1],g_j\le s_i-s_j。每求出 f_i 后,我们都求出 g_ig_i 就是所有满足所有满足 f_i=f_j+1s_i-s_j 的最小值。暴力跑是 O(n^2),可以得到 64 的分数。

如何优化转移的过程呢?容易发现一个特点,f_i 都是单调不降的。因为无论如何只需要将 a_i 加入最后一段,f_i 就可以继承 f_{i-1} 的值,而 a_i>0,所以 s 也具有单调性。因为我们要找到最大的 s_i-s_j,所以我们只需要找到最后一个 j 满足 s_i-s_j\ge g_j 即可最大化 f_i 的同时,最小化 g_i

而对于 s_i-s_j\ge g_j 移项可得 s_i\ge s_j+g_j。对于 k>j,若有 s_k+g_k\le s_j+g_j,则 k 一定比 j 优,则维护一个 s_i+g_i 单调的栈,每次计算 f_i 时则找到最后一个 j,容易发现这个最优决策 j 实际上也是单调的,所以可以换成单调队列。

最后统计方案时,只需记录每个点的最优决策 p_i,然后回溯计算即可,时间复杂度 O(n)

下面给出几点卡常建议:

  1. s_i-s_{p_i} 代替 g_if_i 也可以省去。

  2. 我们对于 n\le 4\times 10^7,实际上不需要开 a,b,s 三个数组,直接用同一个数组生成 a,b,计算前缀和 s

Code

#include <iostream>
#include <cstdio>
using namespace std;
typedef __int128 i128;
typedef long long ll;
const int N=4e7+5;
const int M=1e5+10;
const ll mod=(1<<30);
int f[N],p[N],q[N];
ll s[N];
int n,type,m;
i128 ans=0;
inline ll g(int x){
    return s[x]-s[p[x]];
}
void init(){
    scanf("%d %d",&n,&type);
    if(type==0){
        for(int i=1;i<=n;i++)scanf("%lld",s+i);
    }else{
        long long x,y,z,pp,l,r,lstp=0;
        scanf("%lld %lld %lld %lld %lld %d",&x,&y,&z,s+1,s+2,&m);
        for(int i=3;i<=n;i++)s[i]=(x*s[i-1]+y*s[i-2]+z)%mod;
        for(int i=1;i<=m;i++){
            scanf("%lld %lld %lld",&pp,&l,&r);
            for(int j=lstp+1;j<=pp;j++)s[j]=(s[j]%(r-l+1))+l;
            lstp=pp;
        }
    } 
    for(int i=1;i<=n;i++)s[i]+=s[i-1];
    return;
}
void write(i128 x){
    if(x>9)write(x/10);
    putchar(x%10+'0');
};
int main(){
    init();
    int l=1,r=1;q[1]=0;
    for(int i=1;i<=n;i++){
        while(l<r&&s[i]>=g(q[l+1])+s[q[l+1]])l++;
        p[i]=q[l],f[i]=f[p[i]]+1;
        while(l<=r&&g(q[r])+s[q[r]]>=g(i)+s[i])r--;
        q[++r]=i; 
    }
    for(;n;n=p[n])ans+=(i128)(s[n]-s[p[n]])*(s[n]-s[p[n]]);
    write(ans);
    return 0;
}

如有错误,请指出。