斜率优化入门

· · 个人记录

斜率优化,是用于解决dp中的某些情况,将其优化

由于太菜了,只会有决策单调性的dp

对于一个dp,以P2120为例

f[i]=min(f[i],f[j]+x[i]*(sp[i]-sp[j])-(sxp[i]-sxp[j]));

展开:

fi=fj+xispi-xispj-sxpi+sxpj+ci

整理:

fj+sxpj=xispj - xispi+fi+sxpi-ci

y = k*x - b

把只与j有关的移到左边为y,把只与i有关的移到右边为x;都有关的也在右边,成为a[i]*b[j]等形式,将与i有关的作k,与j有关的作x

然后用一个单调栈维护一个下凸壳

维护:左端斜率小于k,过期弹出;右端不满足下凸壳,弹出

原理证明:

注意:以下均以求bmin为例

对于斜率优化,我们在扫i的时候,我们确定了斜率,b的大部(除fi外),(x,y)均不确定

所以我们可以把问题转为,已知斜率,和一堆点,求该直线过哪一个点的时候截距最大(小)

首先可以证明,这个点一定在某个凸壳上;

如图,我们以bmin为例当点在下凸壳上方时(红),不如在凸壳上优(绿),在凸壳下方时(紫),凸壳不合法(应该变成紫的)

其他同理

当k具有单调性时,我们可以用单调队列维护

如图,k1<k2<k3 当k越来越大时,前面的某些点就永远不可能有机会被扫到,就可以弹出单调队列

具体来说,如果队首两个点,他们组成的斜率<k,那么就弹出

while(r>l&&y(q[l+1])-y(q[l])<=k(i)*(X(q[l+1])-X(q[l]))) l++;
//左端斜率小于k,过期弹出

当然,我们要满足下凸壳,所以有:

while(r>l&&(y(i)-y(q[r]))*(X(q[r])-X(q[r-1]))<=(y(q[r])-y(q[r-1]))*(X(i)-X(q[r]))) r--;
//右端不满足下凸壳,弹出

代数证明:

我们有方程 f[i]=max(f[j]+(si-sj)^2)

若要从j转移比从k转移优,则有

fj+si^si-2sisj+sj^2>fk+si^si-2sisk+sk^2

(fj+sj^2-(fk+sk^2))/(sj-sk)>2si

我们认为fj+sj^2为yj,sj为x,2si为k

那么我们可以看出,只有当j,k之间斜率大于2si时,j优,否则k优

所以当队首两个点的斜率小于k时,队首一定不如后面的点优,所以弹出队首

由图发现,如果为上凸壳,中间的点不可能比两边的优,直接删掉中间的,把两边的连起来就行了

至此,k有单调性的完毕

那么如果k没有单调性,那么我们就用一个单调栈维护下凸壳,有了下凸壳之后用二分查找最符合k的:

int ef(int v) //注意单调栈中,k是单调增的 
{
    int L=l,R=r-1,ans=-1;
    while(L<=R)
    {
        int mid=(L+R)>>1;
        int x=q[mid],y=q[mid+1];
        if((sb[y]-sb[x])*v>f[y]-f[x]+s*sb[x]-s*sb[y]) L=mid+1; //如果v>K,往大的搜 
        //v>((f[y]+s*sb[y])-(f[x]+s*sb[x]))/(sb[y]-sb[x])
        //   K  =           Y             /    X    
        else ans=mid,R=mid-1; //否则往小的搜,并记为合法 
    }
    if(ans==-1) return q[r];
    return q[ans];
}

只有在k有单调性时可以用单调队列

while(r>l&&y(q[l+1])-y(q[l])<=k(i)*(X(q[l+1])-X(q[l]))) l++;
//左端斜率小于k,过期弹出
while(r>l&&(y(i)-y(q[r]))*(X(q[r])-X(q[r-1]))<=(y(q[r])-y(q[r-1]))*(X(i)-X(q[r]))) r--;
//右端不满足下凸壳,弹出

斜率优化代码:

    for(int i=1;i<=n;i++)
    {
        while(r>l&&y(q[l+1])-y(q[l])<=k(i)*(X(q[l+1])-X(q[l]))) l++;
        int j=q[l];
        f[i]=f[j]+x[i]*sp[i]-x[i]*sp[j]-sxp[i]+sxp[j]+c[i];
        while(r>l&&(y(i)-y(q[r]))*(X(q[r])-X(q[r-1]))<=(y(q[r])-y(q[r-1]))*(X(i)-X(q[r]))) r--;
        q[++r]=i;
    }

完整代码

#include <bits/stdc++.h>
#define ll long long
#define il inline
using namespace std;
const int N = 1e6+6;
int n,x[N];
ll sp[N],sxp[N],p[N],c[N],f[N];
il ll k(int i){return x[i];}
il ll X(int i){return sp[i];}
il ll y(int i){return f[i]+sxp[i];}
il ll b(int i){return -x[i]*sp[i]+f[i]+sxp[i]-c[i];}
int q[N],l,r;
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d%d",x+i,p+i,c+i);
        sp[i]=sp[i-1]+p[i];
        sxp[i]=sxp[i-1]+x[i]*p[i];
    }
    f[0]=0;
    for(int i=1;i<=n;i++)
    {
        while(r>l&&y(q[l+1])-y(q[l])<=k(i)*(X(q[l+1])-X(q[l]))) l++;
        int j=q[l];
        f[i]=f[j]+x[i]*sp[i]-x[i]*sp[j]-sxp[i]+sxp[j]+c[i];
        while(r>l&&(y(i)-y(q[r]))*(X(q[r])-X(q[r-1]))<=(y(q[r])-y(q[r-1]))*(X(i)-X(q[r]))) r--;
        q[++r]=i;
    }
    printf("%lld\n",f[n]);
    return 0;
}

对于这种dp单调性的,我们还有一种玄学做法:

第二层,记录上一次的转移,从上一次转移开始扫

在随机数下复杂度十分优秀

例:P3628

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e6+6;
ll read()
{
    ll s=0,f=1;
    char c=getchar();
    while(c<48||c>57) {if(c==-1) f=-1;c=getchar();};
    while(47<c&&c<58) s=10*s+c-48,c=getchar();
    return s;
}
int n,a,b,c,d[N];
ll s[N],f[N];

int main()
{
//  freopen("1.in","r",stdin);
    cin>>n>>a>>b>>c;
    for(int i=1;i<=n;i++)
    {
        scanf("%d",d+i);
        s[i]=s[i-1]+d[i];
    }
    memset(f,143,sizeof(f));
    f[0]=0;
    int lst=0;
    for(int i=1;i<=n;i++)
    {
        for(int j=lst;j<i;j++)
        {
            ll x=s[i]-s[j];
            ll sm=f[j]+1ll*a*x*x+b*x+c;
            if(sm>f[i])
            {
                f[i]=sm;
                lst=j;
            }
        }
    }
    cout<<f[n];
    return 0;
}

关于此操作,目前还不会证明或者卡,有兴趣的可以看讨论

接下来是dp没有单调性的情况 (更准确的说是k不单调

例:P5785

根据dp方程

f[i]=f[j]+s(sb[n]-sb[j])+sa[i](sb[i]-sb[j]);

由于题目里没有保证t为正,使前缀和数组不一定单调,k不一定单调,但是这个凸壳一定存在,所以我们可以二分斜率来找这个点

int ef(int v) //注意单调栈中,k是单调增的 
{
    int L=l,R=r-1,ans=-1;
    while(L<=R)
    {
        int mid=(L+R)>>1;
        int x=q[mid],y=q[mid+1];
        if((sb[y]-sb[x])*v>f[y]-f[x]+s*sb[x]-s*sb[y]) L=mid+1; //如果v>K,往大的搜 
        //v>((f[y]+s*sb[y])-(f[x]+s*sb[x]))/(sb[y]-sb[x])
        //   K  =           Y             /    X    
        else ans=mid,R=mid-1; //否则往小的搜,并记为合法 
    }
    if(ans==-1) return q[r];
    return q[ans];
}
#include<bits/stdc++.h>
#define int long long
#define il inline
using namespace std;
const int N = 3e5+15;
int n,s,sa[N],sb[N],f[N];
int q[N],l,r;
int ef(int v)
{
    int L=l,R=r-1,ans=-1;
    while(L<=R)
    {
        int mid=(L+R)>>1;
        int x=q[mid],y=q[mid+1];
        if((sb[y]-sb[x])*v>f[y]-f[x]+s*sb[x]-s*sb[y]) L=mid+1;
        else ans=mid,R=mid-1; 
    }
    if(ans==-1) return q[r];
    return q[ans];
}
il int y(int i){return f[i]-s*sb[i];}
il int k(int i){return sa[i];}
il int x(int i){return sb[i];}
il int b(int i){return f[i]-sa[i]*sb[i]+s*sb[n];}
main()
{
    cin>>n>>s;
    for(int i=1;i<=n;i++)
    {
        int a,b;
        scanf("%lld%lld",&a,&b);
        sa[i]=sa[i-1]+a;
        sb[i]=sb[i-1]+b;
//      f[i]=1e9;
    }
    for(int i=1;i<=n;i++)
    {
        int j=ef(sa[i]);
        f[i]=f[j]+s*(sb[n]-sb[j])+sa[i]*(sb[i]-sb[j]);
        while(r>l&&(y(i)-y(q[r]))*(x(q[r])-x(q[r-1]))<=(y(q[r])-y(q[r-1]))*(x(i)-x(q[r]))) r--;
        q[++r]=i;
    }
    printf("%lld\n",f[n]);
}

如果dp时二维(伪)的

例:P4072

dp方程:

f[i][k]=min(f[i][k],f[j][k-1]+mxx-2s[n]x);

展开:

f[j][k-1]+2snsj+msjsj=2msisj+f[i][k]+2snsi-msisi

注意到:我们实际上只用在j中取min,并不用在k中取min,k可以认为只是一个参数,甚至可以被化掉。

所以,我们把k作为一个参数,去搞在k下的单调队列优化

il ll X(int i){return s[i];}
il ll K(int i){return 2*m*s[i];}
il ll Y(int i,int j){return f[i][j-1]+2*s[n]*s[i]+m*s[i]*s[i];}
il ll B(int i,int j){return f[i][j]+2*s[n]*s[i]-m*s[i]*s[i];}
    memset(f,127,sizeof(f));f[0][0]=0; //注意到可能出现k>i的非法情况,所以初始化,保证它无法转移 
    for(int k=1;k<=m;k++)
    {
        l=r=0; //初始单调队列
        for(int i=1;i<=n;i++) //带k为参数,照常斜优
        {
            while(r>l&&Y(q[l+1],k)-Y(q[l],k)<K(i)*(X(q[l+1])-X(q[l]))) l++;
            int j=q[l];
            ll x=s[i]-s[j];
            f[i][k]=f[j][k-1]+m*x*x-2*s[n]*x;
            while(r>l&&(Y(i,k)-Y(q[r],k))*(X(q[r])-X(q[r-1]))<=(Y(q[r],k)-Y(q[r-1],k))*(X(i)-X(q[r]))) r--;
            q[++r]=i;
    //      printf("%d %d %d %lld %lld\n",i,k,j,f[j][k-1],f[i][k]);
        }   
    }
#include <bits/stdc++.h>
#define ll long long
#define int long long
#define il inline
using namespace std;
const int N = 3333;
int n,m,a[N];
ll s[N],f[N][N];
il ll X(int i){return s[i];}
il ll K(int i){return 2*m*s[i];}
il ll Y(int i,int j){return f[i][j-1]+2*s[n]*s[i]+m*s[i]*s[i];}
il ll B(int i,int j){return f[i][j]+2*s[n]*s[i]-m*s[i]*s[i];}
int q[N],l,r;
main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",a+i);
        s[i]=s[i-1]+a[i];
    }
    memset(f,127,sizeof(f));f[0][0]=0; //注意到可能出现k>i的非法情况,所以初始化,保证它无法转移 
    for(int k=1;k<=m;k++)
    {
        l=r=0;
        for(int i=1;i<=n;i++)
        {
            while(r>l&&Y(q[l+1],k)-Y(q[l],k)<K(i)*(X(q[l+1])-X(q[l]))) l++;
            int j=q[l];
            ll x=s[i]-s[j];
            f[i][k]=f[j][k-1]+m*x*x-2*s[n]*x;
            while(r>l&&(Y(i,k)-Y(q[r],k))*(X(q[r])-X(q[r-1]))<=(Y(q[r],k)-Y(q[r-1],k))*(X(i)-X(q[r]))) r--;
            q[++r]=i;
    //      printf("%d %d %d %lld %lld\n",i,k,j,f[j][k-1],f[i][k]);
        }   
    }
    printf("%lld\n",f[n][m]+s[n]*s[n]);
    return 0;
}

如果说复杂度较为宽松,而又有两个关于ij的式子

可以试着枚举一维,另一维斜优

例:P4056

注意到这题复杂度较为宽松,nm可以卡过去而dp方程为

f[i]=max(f[j]-(xi-xj)^2-(yi-yj)^2)+vi

展开:fi=fj+vi-xixi+2xixj-xjxj-yiyi+2yiyj-yjyj

注意到有-2xixj-2yiyj两项,无法直接斜优

所以我们把y这一维给固定下来:

我们定义 q[i][j]:在i行的单调队列,li,ri为队首队尾

    sort(p+1,p+1+n,cmp); //x优先升序,y次级优先升序 
    memset(f,143,sizeof(f));
//  f[0]=0;
    f[1]=0;
    for(int i=1;i<=n;i++)
    {
        int t=p[i].y;    //找到这一行 
        for(int j=1;j<=t;j++) //根据规则,只能从在他前面的转移过来 
        {
            int c=sqr(t-j);  //确定一项,对另一项斜优 
            while(r[j]>l[j]&&y(q[j][l[j]+1])-y(q[j][l[j]])<=k(i)*(x(q[j][l[j]+1])-x(q[j][l[j]]))) l[j]++;
            int jj=q[j][l[j]];
            f[i]=max(f[i],f[jj]-c-sqr(p[i].x-p[jj].x)); //由于是多列最优求更优,所以要有max 
    //      printf("%d %d %d %d\n",j,l[j],jj,f[jj]-c-sqr(p[i].x-p[jj].x));
        }
        while(r[t]>l[t]&&(y(i)-y(q[t][r[t]]))*(x(q[t][r[t]])-x(q[t][r[t]-1]))<=(y(q[t][r[t]])-y(q[t][r[t]-1]))*(x(i)-x(q[t][r[t]]))) r[t]--;
        q[t][++r[t]]=i; //把i加入到第yi行的单调队列中 
        f[i]+=p[i].v;
    //  printf("# %d %d    ####\n",i,f[i]);
    }
#include <bits/stdc++.h>
#define il inline
using namespace std;
const int N = 2e5+10;
const int M = 1e3+3;
int n,m;
struct bear
{
    int x,y,v;
}p[N];
bool cmp(bear a,bear b)
{
    if(a.x==b.x) return a.y<b.y;
    return a.x<b.x;
}
il int sqr(int x){return x*x;}
il int dis(bear a,bear b)
{
    return sqr(a.x-b.x)+sqr(a.y-b.y);
}
int f[N],q[M][M],l[M],r[M];
il int y(int i){return -f[i]+sqr(p[i].x);}
il int k(int i){return 2*p[i].x;}
il int x(int i){return p[i].x;}
int main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d%d",&p[i].x,&p[i].y,&p[i].v);
    }
    sort(p+1,p+1+n,cmp); //x优先升序,y次级优先升序 
    memset(f,143,sizeof(f));
//  f[0]=0;
    f[1]=0;
    for(int i=1;i<=n;i++)
    {
        int t=p[i].y;    //找到这一行 
        for(int j=1;j<=t;j++) //根据规则,只能从在他前面的转移过来 
        {
            int c=sqr(t-j);  //确定一项,对另一项斜优 
            while(r[j]>l[j]&&y(q[j][l[j]+1])-y(q[j][l[j]])<=k(i)*(x(q[j][l[j]+1])-x(q[j][l[j]]))) l[j]++;
            int jj=q[j][l[j]];
            f[i]=max(f[i],f[jj]-c-sqr(p[i].x-p[jj].x)); //由于是多列最优求更优,所以要有max 
    //      printf("%d %d %d %d\n",j,l[j],jj,f[jj]-c-sqr(p[i].x-p[jj].x));
        }
        while(r[t]>l[t]&&(y(i)-y(q[t][r[t]]))*(x(q[t][r[t]])-x(q[t][r[t]-1]))<=(y(q[t][r[t]])-y(q[t][r[t]-1]))*(x(i)-x(q[t][r[t]]))) r[t]--;
        q[t][++r[t]]=i; //把i加入到第yi行的单调队列中 
        f[i]+=p[i].v;
    //  printf("# %d %d    ####\n",i,f[i]);
    }
    cout<<f[n];
    return 0;
}