题解:CF660F Bear and Bowling 4

· · 题解

前言

直接数据结构代替思考了。

思路

首先因为是连续的子段所以不难想到分治,那么现在变成了你如何把两段拼起来,我们假设 l\sim mid 里面取了一段长度为 len 贡献为 s 的,然后 mid+1\sim r 中选一段总和为 d 的,我们考虑先按照第一个位置为 1 的方式选然后考虑和前一段拼上之后的贡献的增量,我们发现这是一个整体系数加 len 的过程,即贡献增加 len\times d 然后我们把这个式子列出来就成了 s+len\times d+v 其中的 v 是右边这一段当作起始点为 1 来算的贡献,有了这个东西就不难想到李超线段树了,直接每次分治把右部分的线段加入李超然后左部分在李超线段树上求最值即可,复杂度大概是 O(n\log^2{n}) 但实际上跑得非常快因为李超基本跑不满。

代码

#include <bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace std;
#define pb push_back
#define rep(i,x,y) for(register int i=x;i<=y;i++)
#define rep1(i,x,y) for(register int i=x;i>=y;--i)
#define int long long
#define fire signed
#define il inline
template<class T> il void print(T x) {
    if(x<0) printf("-"),x=-x;
    if (x > 9) print(x / 10);
    putchar(x % 10 + '0');
}
template<class T> il void in(T &x) {
    x = 0; char ch = getchar();
    int f = 1;
    while (ch < '0' || ch > '9') {if(ch=='-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
int T=1;
const int N=2e5+10;
struct infom{
    int k,b;
    int id;
};
int calc(infom a,int x) {
    return a.k*x+a.b;
}
int idx;
struct node{
    int l,r;
    infom mx;
}tr[N*20];
int rt;
void modify(int &u,int l,int r,infom x) {
    if(!u) {
        u=++idx;
        tr[u].l=tr[u].r=false;
        tr[u].mx.id=false;
    }
    if(!tr[u].mx.id) {
        tr[u].mx=x;
        return;
    }
    int mid=l+r>>1;
    int l1=calc(x,l),r1=calc(x,r);
    int l2=calc(tr[u].mx,l),r2=calc(tr[u].mx,r);
    if(l1>l2&&r1>r2) {
        tr[u].mx=x;
        return;
    }else if(l1>l2||r1>r2){
        int mid=l+r>>1;
        int mid1=calc(x,mid),mid2=calc(tr[u].mx,mid);
        if(mid1>mid2) {
            swap(x,tr[u].mx);
        }
        l1=calc(x,l),r1=calc(x,r);
        l2=calc(tr[u].mx,l),r2=calc(tr[u].mx,r);
        if(l1>l2) modify(tr[u].l,l,mid,x);
        else modify(tr[u].r,mid+1,r,x);
    }
}
int Ans(int u,int l,int r,int k) {
    if(!u||!tr[u].mx.id) return 0;
    int res=calc(tr[u].mx,k);
    if(l==r) {
        return res;
    }
    int mid=l+r>>1;
    if(mid>=k) res=max(res,Ans(tr[u].l,l,mid,k));
    else res=max(res,Ans(tr[u].r,mid+1,r,k));
    return res;
}
int ans=0;
int a[N],n;
void get(int l,int r) {
    if(l==r) {
        ans=max(ans,a[l]);
        return;
    }
    int mid=l+r>>1;
    get(l,mid);
    get(mid+1,r);
    idx=0;
    rt=0;
    int sum=0,s1=false;
    rep(i,mid+1,r) {
        s1+=a[i];
        sum+=a[i]*(i-mid);
        infom now={s1,sum,1};
        ans=max(ans,sum);
        modify(rt,1,n,now);
    }
    int s2=0,ss=0;
    rep1(i,mid,l) {
        s2+=a[i]*(mid-i);
        ss+=a[i];
        int dis=ss*(mid-i+1)-s2;
        ans=max(ans,dis+Ans(rt,1,n,mid-i+1));
    }
}
void solve() {
    in(n);
    rep(i,1,n) in(a[i]);
    get(1,n);
    printf("%lld\n",ans);
}
fire main() {
    while(T--) {
        solve();
    }
    return false;
}