赤石

· · 个人记录

被szzjz骗去做这道题,给我赤石赤饱了。

给定一个长度为 $n$ 的序列,有 $q$ 次询问,每次给出 $l,r,k$,求在区间 $[l,r]$ 中取出 $k$ 个互不相交的非空子段时,元素和的最大值。

首先,答案关于 $k$ 是凸的,可以用费用流模型简单证明。只有一次询问时,可以维护凸函数的 $(\max,+)$ 卷积,每次合并两个函数的复杂度为 $O(n)$。

在有多个函数时,无法每次都将合并后的函数完整地构造出来;不过考虑到只需要知道 $f(k)$ 这一个值,可以使用 wqs 二分来简化计算。

考虑一条斜率为 $k$ 的直线(这里的 $k$ 指 wqs 二分的斜率,而不是询问中的段数)。我们只需要知道原函数与它的切点的坐标,因此可以对每个线段树区间的函数分别求出切点坐标,再把它们加起来即可。这样做的复杂度为 $O(n \log^2 n \log V)$。使用整体二分即可将复杂度降至 $O(n \log n \log V)$。

注意平台(凸函数上相邻若干点共线,即同一个斜率对应多个 $k$)的处理。

#include<bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit integer type used for all sums.
#define ll long long
// Left / right child index of segment-tree node `x`. These expand textually,
// so they only work where a variable named `x` is in scope, and they are not
// parenthesised.
#define lc x << 1
#define rc x << 1 | 1
// Pair of two ll, plus shorthands for its members and for make_pair.
#define pll pair<ll,ll >
#define fir first
#define sec second
#define mk make_pair

// Magnitude of the "impossible" sentinel; used negated (-INF) in shift() and mat::operator*.
const ll INF = 1e18;
// MX: capacity of the input and query arrays (the segment tree has MX << 2 nodes).
// oo: main() searches penalties in the range [-oo, oo].
const int MX = 35000 + 7,oo = 2e9;

// (max,+) convolution of two concave sequences: C[t] = max over i + j = t of A[i] + B[j].
// Precondition: the successive differences of A and of B are non-increasing;
// that is what makes greedily taking the larger next difference optimal.
// Returns a sequence of A.size() + B.size() - 1 entries, or an empty sequence
// when either input is empty (there is no index pair to combine).
vector<ll > add(const vector<ll > &A,const vector<ll > &B){
    if(A.empty() || B.empty())return {};
    vector<ll > C;
    C.reserve(A.size() + B.size() - 1);
    C.push_back(A[0] + B[0]);
    size_t i = 0,j = 0;
    while(i + 1 < A.size() || j + 1 < B.size()){
        bool takeA;
        if(j + 1 >= B.size())takeA = true;
        else if(i + 1 >= A.size())takeA = false;
        // On equal differences A goes first; the resulting values are the same either way.
        else takeA = A[i + 1] - A[i] >= B[j + 1] - B[j];
        if(takeA)i++;
        else j++;
        // The differences taken so far telescope to exactly A[i] + B[j].
        C.push_back(A[i] + B[j]);
    }
    return C;
}

// Element-wise maximum of two sequences. The result is as long as the longer
// input; at indices that exist in only one input, that input's entry is used.
// Inputs are taken by const reference: they are only read, so copying them is waste.
vector<ll > max(const vector<ll > &A,const vector<ll > &B){
    const bool aLonger = A.size() >= B.size();
    const vector<ll > &longer = aLonger ? A : B;
    const vector<ll > &shorter = aLonger ? B : A;
    vector<ll > C(longer);
    for(size_t i = 0;i < shorter.size();i++)C[i] = std::max(C[i],shorter[i]);
    return C;
}

// Returns a copy of A in which the constant v has been added to every entry.
vector<ll > addint(vector<ll > A,int v){
    for(ll &entry : A)entry += v;
    return A;
}

// Moves every entry one index to the right: result[i + 1] = A[i].
// The new index 0 is filled with -INF.
vector<ll > shift(vector<ll > A){
    A.insert(A.begin(),-INF);
    return A;
}

// One segment-tree node; s[] is the tree itself, with the root at index 1.
//   a[c0][c1]      - value sequences filled by build()/merge(), indexed by a count.
//                    NOTE(review): from merge(), c0 = 1 / c1 = 1 appears to mean that a
//                    piece touching the left / right end of the range is still open
//                    (not yet counted as a segment) - confirm.
//   pos[v][c0][c1] - forward-only cursor into a[c0][c1] for recursion layer v of solve()
//                    (advanced by movepos); 40 layers are reserved.
//   sum            - total of the elements in the node's range (set by build()).
struct node{
    vector<ll > a[2][2];
    // sum is value-initialised so that a node returned by merge() never carries
    // an indeterminate int.
    int pos[40][2][2] = {},sum = 0;
}s[MX << 2];

// Summary of a contiguous range for one fixed per-segment penalty k.
//   a[i][j] - (penalised value, count) pair for end-state (i, j); pairs are compared
//             lexicographically, so on equal value the larger count wins.
//   k       - the penalty subtracted once for every counted segment.
//   sum     - total of all elements in the range.
// operator* combines the summary of a left range (*this) with the summary of the
// range immediately to its right (b).
struct mat{
    pll a[2][2];ll k,sum;
    mat operator * (const mat &b)const {
        mat out;
        out.k = k;
        out.sum = sum + b.sum;
        for(auto &row : out.a)
            for(auto &cell : row)cell = mk(-INF,-INF);
        // Keep the lexicographically larger of `best` and (value, count).
        const auto relax = [](pll &best,ll value,ll count){
            const pll cand(value,count);
            if(best < cand)best = cand;
        };
        for(int lo = 0;lo < 2;lo++)
            for(int hi = 0;hi < 2;hi++){
                // Inner states 0/0: the two sides are simply placed next to each other.
                relax(out.a[lo][hi],a[lo][0].fir + b.a[0][hi].fir,a[lo][0].sec + b.a[0][hi].sec);
                // Inner states 1/1: one more segment is counted, so pay the penalty once.
                relax(out.a[lo][hi],a[lo][1].fir + b.a[1][hi].fir - k,a[lo][1].sec + b.a[1][hi].sec + 1);
            }
        // One side contributes its whole range (its plain sum) with no extra count.
        relax(out.a[0][1],a[0][1].fir + b.sum,a[0][1].sec);
        relax(out.a[1][0],sum + b.a[1][0].fir,b.a[1][0].sec);
        relax(out.a[1][1],a[1][1].fir + b.sum,a[1][1].sec);
        relax(out.a[1][1],sum + b.a[1][1].fir,b.a[1][1].sec);
        // Both whole ranges, counted as a single segment.
        relax(out.a[1][1],sum + b.sum - k,1);
        return out;
    }
};

// A[1..n]: the input sequence (read in main); n: its length; m: number of queries.
int A[MX],n,m;

// Builds the node for the concatenation of two adjacent ranges.
//   x, y       - nodes of the left and right range; taken by const reference because
//                copying a node deep-copies four vectors and the whole cursor table.
//   lsum, rsum - total of all elements in x's / y's range (build() passes the
//                children's `sum` fields).
// The returned node has every cursor in `pos` at 0 and `sum` set to lsum + rsum.
node merge(const node &x,const node &y,int lsum,int rsum){
    node ret;
    ret.sum = lsum + rsum;
    for(int c0 = 0;c0 < 2;c0++)
        for(int c1 = 0;c1 < 2;c1++)
            // Either the inner ends are both in state 0 (plain convolution), or both
            // in state 1, in which case the count index moves up by one (shift).
            ret.a[c0][c1] = max(add(x.a[c0][0],y.a[0][c1]),shift(add(x.a[c0][1],y.a[1][c1])));
    // One child contributes its whole range (its plain sum) without changing the count.
    ret.a[0][1] = max(ret.a[0][1],addint(x.a[0][1],rsum));
    ret.a[1][0] = max(ret.a[1][0],addint(y.a[1][0],lsum));
    ret.a[1][1] = max(ret.a[1][1],addint(x.a[1][1],rsum));
    ret.a[1][1] = max(ret.a[1][1],addint(y.a[1][1],lsum));
    // Both children taken whole, at count 0.
    ret.a[1][1][0] = max(ret.a[1][1][0],1ll * lsum + rsum);
    return ret;
}

void build(int x,int l,int r){
    node &cur = s[x];
    if(l == r){
        cur.a[0][0].resize(2);cur.a[0][0][0] = 0;cur.a[0][0][1] = A[l];s[x].sum = A[l];
        cur.a[0][1].resize(1);cur.a[0][1][0] = A[l];
        cur.a[1][0].resize(1);cur.a[1][0][0] = A[l];
        cur.a[1][1].resize(1);cur.a[1][1][0] = A[l];
    }
    else{
        int mid = l + r >> 1;
        build(lc,l,mid);build(rc,mid + 1,r);s[x] = merge(s[lc],s[rc],s[lc].sum,s[rc].sum);s[x].sum = s[lc].sum + s[rc].sum;
        //cerr << l << ' ' << r << '\n';
        //for(auto it : s[x].a[0][0])cerr << it << ' ';
        //cerr << "\n\n";
    }
}

// Per-query data, indexed 1..m: l[i], r[i] are the bounds of the queried range,
// c[i] is the segment count that solve() matches against, and ans[i] receives
// the result (written by solve, printed by main).
int l[MX],r[MX],c[MX];ll ans[MX];

// Advances the layer-v cursor of s[x].a[c0][c1] while the next step of the sequence
// gains at least k, i.e. while a[pos + 1] - a[pos] >= k. The cursor only ever moves
// forward, so within one layer callers must pass non-increasing values of k.
void movepos(int x,int c0,int c1,ll k,int v){
    const vector<ll > &vals = s[x].a[c0][c1];
    int &cur = s[x].pos[v][c0][c1];
    while(cur + 1 < vals.size() && vals[cur + 1] - vals[cur] >= k)cur++;
}

// Returns the penalised summary (see mat) of A[ql..qr] for penalty k, using the
// layer-v cursors. x is the current node and [l, r] the range it covers.
// A fully covered node contributes, for each state, the point its cursor stops at:
// (a[pos] - k * pos, pos). Partially covered nodes combine their children with
// mat::operator*.
mat query(int x,int l,int r,int ql,int qr,int v,ll k){
    if(ql <= l && r <= qr){
        mat ret;
        ret.k = k;
        ret.sum = s[x].sum;
        for(int c0 = 0;c0 < 2;c0++)
            for(int c1 = 0;c1 < 2;c1++){
                movepos(x,c0,c1,k,v);
                const int at = s[x].pos[v][c0][c1];
                ret.a[c0][c1] = pll(s[x].a[c0][c1][at] - k * at,at);
            }
        return ret;
    }
    const int mid = (l + r) >> 1;
    if(ql > mid)return query(x << 1 | 1,mid + 1,r,ql,qr,v,k);
    if(qr <= mid)return query(x << 1,l,mid,ql,qr,v,k);
    const mat leftPart = query(x << 1,l,mid,ql,qr,v,k);
    const mat rightPart = query(x << 1 | 1,mid + 1,r,ql,qr,v,k);
    return leftPart * rightPart;
}

void solve(ll L,ll R,vector<int > q,int layer){
    if(q.empty())return;
    if(L == R){
        //cerr << "L = " << L << " R = " << R << '\n';
        for(auto it : q){
            mat x = query(1,1,n,l[it],r[it],layer,L);
            ans[it] = x.a[0][0].fir + L * c[it];
            //cerr << x.a[0][0].fir << ' ' << x.a[0][0].sec << ' ' << L * c[it] << '\n';
        }
        return;
    }
    ll MID = (L + R + 1) >> 1;vector<int > LL,RR;
    //cerr << L << ' ' << R << '\n';
    for(auto it : q){
        mat x = query(1,1,n,l[it],r[it],layer,MID);
        //cerr << "MID = " << MID << " val = " << x.a[0][0].fir << " num = " << x.a[0][0].sec << '\n';
        if(x.a[0][0].sec < c[it])LL.push_back(it);
        else if(x.a[0][0].sec > c[it])RR.push_back(it);
        else ans[it] = x.a[0][0].fir + MID * c[it];
    }
    solve(MID,R,RR,layer + 1);solve(L,MID - 1,LL,layer + 1);
}

// Indices 1..m of all queries; main() fills this list and passes it to solve().
vector<int > q;

int main(){
    cin >> n >> m;
    for(int i = 1;i <= n;i++)cin >> A[i];
    for(int i = 1;i <= m;i++)cin >> l[i] >> r[i] >> c[i];
    for(int i = 1;i <= m;i++)q.push_back(i);
    build(1,1,n);/*cerr << "st\n";*/solve(-oo,oo,q,0);
    for(int i = 1;i <= m;i++)cout << ans[i] << '\n';
    return 0;                                                
}//591 60 15 3