奥法之劫 题解

· · 个人记录

这里是 O(n) 的做法。

我们先从一个朴素的 O(n^3) dp讲起。

f[i] 表示我处理到 i 号节点的最小代价和,那么转移有:

f[i]=\min(f[k]+\mathrm{cost}(k+1,i))

那么我们可以进行一步步优化到 O(n)

先关注 p_i\geqslant 0 的情况,有一些我们不得不拆的 a_i,我把它们称作“代价”。

这些代价不仅和当前 i 的位置是有关,还与 b 的高度有关。

我们发现代价是不好计算的。随着 i 的变化,代价也随之变化,所以每次都要 O(n) 扫描。

换一种思路,我们可以考虑每个 a_i 对答案的贡献,把它们挂在对应的节点上,就可以快速计算。

也就是说,我们要找到对于每一个 a_k,找到第一个 a_i,使得选中 a_i 时必须要拆除 a_k

具体而言,我们可以在 b 数组中 lower\_bound\ a_k,把 p_k 加到对应的高度上。

那么每次我想选一个 a_i,就必须要拆除所有挂在 b_j 上的代价,因为它们没有被选中且未被挡住。(其中 a_i=b_j

这样我们砍掉了计算区间贡献的 n,而换成了 \log n

再想想,由于 b 数组是单调的,所以我们可以一遍扫一遍存下来 lower\_bound 的结果,就可以做到 O(1) 了。

但此时,我们的转移点还是不确定,并且没有考虑 p_i<0 的情况。

对于 p_i<0 的部分,我们贪心的想肯定是越选多越好,所以除了把 i 选中的情况我们的答案都应该加上这些 p_i

没错,这里的影响就只和 i 的位置有关了,同样有关的还有我们的 f 数组。

其实我们 dp 的本质,就是在最小化这些东西,而之前的操作是为了方便计算必须要拆的代价。

我们设 a_i 对应的 b 的位置为 pos_ii 之前 p_i<0 部分的和为 sum_i,那么转移有:

f[i]=\min_{pos_k<pos_i}(f[k]-sum_k)+sum_i+cost_{j}

其中 cost_j 表示在 i 位置保留 b_j 高度的代价。

我们发现 min 里面的可以记一个前缀最小值,这样就可以 O(1) 转移了。

这样最小化了能够最小化的部分,也统计了代价。

最后注意几点:

  1. 为了保证最后答案能够统计到,要在 n+1 的位置插一个极大值。

  2. 如果存在 a_i\neq b_j 对于任意的 j ,那么它不能参与统计答案,但必须要参与代价以及 sum 的计算。

  3. 其实 fsum 数组根本没必要开。~

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
const int N=5e6+9;
const ll INF=1e18;
inline int read(){
    int res=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        res=(res<<1)+(res<<3)+ch-'0';
        ch=getchar();
    }
    return res*f;
}
int n,a[N],b[N],Q[N],m,pos[N];
ll f[N],g[N],cost[N],p[N],sum;
int main()
{
    // freopen("hs.in","r",stdin);
    // freopen("offa.out","w",stdout);
    n=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=1;i<=n;i++) p[i]=read();
    n++,a[n]=n;
    m=read();
    for(int i=1;i<=m;i++) b[i]=read();
    for(int i=1;i<=m;i++) pos[b[i]]=i;
    m++,b[m]=n,pos[n]=m;
    for(int i=1;i<N;i++) g[i]=INF;
    for (int i=1,j=1;i<=n;++i)
    {
        if(b[j]<i) j++;
        Q[i]=j;
    }
    ll tmp=0;
    for (int i=1;i<=n;i++)
    {
        int j=pos[a[i]];
        if(j)
        {
            if(g[j-1]>=INF) f[i]=INF;
            else f[i]=g[j-1]+cost[j]+sum;
        }
        if(p[i]>=0) cost[Q[a[i]]]+=p[i];
        else sum+=p[i];
        if(j&&f[i]<INF) g[j]=min(g[j],f[i]-sum);
    }
    if (f[n]<INF) printf("%lld",f[n]);
    else puts("Impossible");
    return 0;
}