题解:P12930 [USACO4.3] 逢低吸纳 Buy Low, Buy Lower 加强版

· · 题解

简单求 \text{LIS} 的题目。

思路

f[i] 表示以第 i 天的股价 a[i] 结尾的最长递减子序列长度(实际并不需要维护 f[i]),cnt[i] 表示达到该长度的不同价格序列的个数。转移时,从所有 j<ia[j]>a[i] 的状态中选取 f[j] 最大的那些,则 f[i]=\max\{f[j]\}+1,并且

cnt[i]=\sum_{\substack{j<i\\a[j]>a[i]\\f[j]=f[i]-1}}cnt[j]

如果不存在这样的 j,则 f[i]=1,cnt[i]=1

问题的关键是快速查询“所有大于当前股价的历史状态中的最优值”。因为股价范围大但不同值最多 n 个,先离散化。设排序去重后共有 m 个不同价格,每个股价映射到排名 1\ldots m(排名越大股价越高)。对于当前股价排名 id,需查询区间 [id+1,m]

用线段树维护每个排名上的最优信息。每个节点保存一对 (\text{len},\text{cnt}),合并规则:

一次区间查询即可得到最优情况。

代码

码风不咋地,一堆 \text{define} 主要是为了简短,希望读者不要抄袭。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll M=1e9+7,N=1e6+9;
#define pll pair<ll,ll>
#define fi first
#define se second
#define B begin
#define E end
ll n,a[N],sz;
vector<ll> v,sl,sc;
pll merge(pll x,pll y){
    if(x.fi>y.fi)return x;
    if(y.fi>x.fi)return y;
    return {x.fi,(x.se+y.se)%M};
}
void init(ll m){
    sz=1;
    while(sz<m)sz<<=1;
    sl.assign(2*sz,0),sc.assign(2*sz,0);
}
void upd(ll pos,ll len,ll cnt){
    ll p=pos+sz-1;
    sl[p]=len,sc[p]=cnt,p>>=1;
    while(p){
        ll l=p<<1,r=l|1;
        if(sl[l]>sl[r])sl[p]=sl[l],sc[p]=sc[l];
        else if(sl[r]>sl[l])sl[p]=sl[r],sc[p]=sc[r];
        else sl[p]=sl[l],sc[p]=(sc[l]+sc[r])%M;
        p>>=1;
    }
}
pll query(ll l,ll r){
    if(l>r)return {0,0};
    l+=sz-1,r+=sz-1;
    pll res={0,0};
    while(l<=r){
        if(l&1)res=merge(res,{sl[l],sc[l]}),l++;
        if(!(r&1))res=merge(res,{sl[r],sc[r]}),r--;
        l>>=1,r>>=1;
    }
    return res;
}
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;
}
inline ll read(){
    ll x=0,f=1;
    char ch=getc();
    while(ch<'0'||ch>'9'){
        if(ch=='-')f=-1;
        ch=getc();
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getc();
    }
    return x*f;
}
int main(){
    n=read();
    for(ll i=1;i<=n;i++)a[i]=read(),v.push_back(a[i]);
    sort(v.B(),v.E()),v.erase(unique(v.B(),v.E()),v.E());
    ll m=v.size();
    init(m);
    vector<ll> bl(m+1,0),bc(m+1,0);
    for(ll i=1;i<=n;i++){
        ll id=lower_bound(v.B(),v.E(),a[i])-v.B()+1;
        pll res=query(id+1,m);
        ll li,ci;
        if(res.fi==0)li=1,ci=1;
        else li=res.fi+1,ci=res.se;
        ll ol=bl[id],oc=bc[id];
        if(li>ol)ol=li,oc=ci,upd(id,li,ci);
        else if(li==ol)if(ci>oc)oc=ci,upd(id,li,ci);
    }
    pll ans=query(1,m);
    cout<<ans.fi<<" "<<ans.se%M;
    return 0;
}

这样就能在 O(n\log n) 时间内解决本题。