题解:P12646 [KOI 2024 Round 1] 升序

· · 题解

A 掉这道题才发现自己是第 9 个 A 的,抓紧来发篇题解。

题目传送门

[KOI 2024 Round 1] 升序

题意理解

一看到这道题,我发现其实就是 [KOI 2024 Round 1] 加倍 的加强版,多了一个区间查询。

题意:给你一个长度为 M 的正整数序列 X_1,X_2, \dots,X_M,给你 l,r,每次操作可以把 X_i2,求出让 X_l, X_{l+1}, \dots, X_r 变为升序的最小操作次数。

解题思路

简化一下,对于 \forall i,X_i \leq X_{i+1} \times 2^{c_i},我们可以把每个 c_i 存下来,如果 X_i 乘了 1 次,那么后面的 X_j 都会多乘 1 次,所以前面的会对后面的产生影响,那么就要计算前缀和,它们的前缀和加起来就是答案————————吗?

举个例子,序列 3,1,5,1,5 ,其对应的 c_i 分别为 2,0,3,0,前缀和就是 2,2,5,5,结果为 14,很明显是不对的。

思考一下,发现需要计算每个 X_i 可以抵消掉后面的次数,那么这个 c_i 就分别是 2,-2,3,-2,前缀和加起来就是 6,这才是正确的。

But,序列 1,10,5c_i-3,1,但是前缀和加起来就是 -5,很明显的错误,所以如果前缀和变成了负数,要令其变为 0

综上,我们可以归纳出,c_iX_iX_{i+1} 之间的变化,前缀和 S_i = \max(0, S_{i-1}+c_i),对于区间 l,r,答案就是 \sum_{i=l}^r S_i,如果每次跑一遍累加一下,复杂度是 O(MQ),只能得到 12pts。

考虑优化,将 \sum_{i=l}^r S_i 拆开,得到 \sum^r_{i=l} \sum^i_{j=l} c_j,交换两个求和符号得 \sum^r_{j=l} \sum^r_{i=j} c_j,再化简一下是 \sum^r_{j=l}(r-j+1)c_j,将 (r+1) 提出来得 (r+1)\sum^r_{j=l} c_j-\sum^r_{j=l}j\times c_j,所以这就是答案的式子,可以用两个前缀和维护,就可以实现 O(1) 查询,但是由于 S_i = \max(0, S_{i-1}+c_i),设 R_i=R_{i-1}+c_i,所以这个式子只能用于 R_i ≥ 0 的区间。

对于 R_i < 0 的情况,可以考虑分段,将连续的 R_i(R_i ≥ 0) 分为一段,由于这一段 R_i ≥ 0,所以就可以用 (r+1)\sum^r_{j=l} c_j-\sum^r_{j=l}j\times c_j 算出每一段的答案,再拼成 l,r 的区间。

如何实现呢?我们知道,如果 R_j<R_i (j>i),那么这一段就是 <0 的,所以需要找到第一个比 R_i 小的 R_j,可以用单调栈维护,存入 nxt_i 表示第一个比 R_i 小的 R_jj

代码实现

每次出现负数,原数组该位置的值至少 \div 2,所以分段个数为 log W

复杂度大约是 O(n+QlogV),注意开 long long。

#include<bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int N=3e5+10;
int n,q,a[N],c[N],sum,cnt,s[N],ss[N],nxt[N];
stack<int> st;
int getsum(int l,int r){
    return (r+1)*(s[r]-s[l-1])-(ss[r]-ss[l-1]);
}
void zj(){
    //预处理nxt[i]
    for(int i=n;i>=0;i--){
        while(!st.empty()&&s[st.top()]>=s[i]) st.pop();
        if(!st.empty()) nxt[i]=st.top();
        else nxt[i]=n+1;
        st.push(i);
    }
    while(q--){
        int l,r,ans=0;
        cin>>l>>r;
        if(l==r){
            cout<<0<<endl;
            continue;
        }
        int i=l-1;
        while(i<r){
            ans+=getsum(i+1,min(nxt[i]-1,r-1));//计算每一段的和
            i=nxt[i];//跳到下一段
        }
        cout<<ans<<endl;
    }
}
signed main(){
    cin.tie(0)->ios::sync_with_stdio(0);
    cin>>n>>q;
    for(int i=1;i<=n;i++) cin>>a[i];
    //计算 c[i]
    for(int i=1;i<n;i++){
        if(a[i+1]>a[i]){
            int x=a[i];
            while(x*2<=a[i+1]){
                x*=2;
                c[i]--;
            }
        }else{
            int x=a[i+1];
            while(x<a[i]){
                x*=2;
                c[i]++;
            }
        }
    }
    for(int i=1;i<=n;i++) s[i]=s[i-1]+c[i];//s[i]就是题解中的R[i]
    for(int i=1;i<=n;i++) ss[i]=ss[i-1]+c[i]*i;
    zj();
    return 0;
}

写在最后

感谢 sheryang 提供的思路,谢谢观看,管理大大辛苦啦!