CF1762F 题解

· · 题解

思路

发现若合法子序列的三个连续元素 x,y,z 满足 a_x\ge a_y,a_y\le a_z,那么一定有 \left| a_z-a_x\right|\le k,可以删掉中间的 a_ya_x\le a_y,a_y\ge a_z 的也一样。所以所有合法子序列均可通过删除操作变得单调不增或单调不降,因此只考虑单调序列即可。

显然单调不增和单调不降两种序列只会在 a_l=a_r 时重复计算,因此分别计算后减去所有的重复贡献即可。可以先计算单调不增的区间数,然后进行 a_i\leftarrow P-a_i+1 的转化(Pa_i 的值域),就把原来的单调不降变成了单调不增,再做一遍即可。

考虑怎么处理这个序列,发现对于 i 后面第一个 a_i\le a_j\le a_i+kj,如果直接接上这个数,对于后面继续选不小于 a_j 的数一定不劣。那么设 f_i 表示 l=i 时单调不增序列的结尾 r 的个数,可以直接先把 f_i 累加上 f_j

剩下的就只有 j 后面小于 a_j 的数,这些数不能接在 a_j 后面,需要另外统计。根据 j 是第一个合法数的定义,j 前面已经不可能有 [a_i,a_j] 内的数,所以需要另外统计的个数等价于 i 后面在 [a_i,a_j) 范围内数的个数。

现在需要求的是每个数后面的 j(代码中为 nxt_i),考虑按照 a 的值从大到小处理,用 set 维护目前在 [x,x+k] 范围内的数,求 nxt 时在 set 中二分即可,需要注意 a 值相等的位置要从后往前加入 set,才能在前面计算时找到后面相等的数。另外还需要维护每个数后面在 [a_i,a_{nxt_i}) 内的数,这个在倒序处理时维护权值树状数组即可实现。

另外还需要注意所有的清空都需要清空每个 a_i 的贡献,而不是对整个数组清空,否则时间复杂度是假的。

代码

#include<iostream>
#include<vector>
#include<set> 
#include<algorithm>
#define int long long
using namespace std;
const int N=5e5+10;
const int P=1e5;
int n,k,res,a[N],t[N],nxt[N],f[N];
set <int> pos;
vector <int> tp[N],tv; 
struct bit
{
    int b[N];
    int lowbit(int x){return x&(-x);}
    void add(int p,int x){for(;p<=P;p+=lowbit(p)) b[p]+=x;}
    int query(int p){int tr=0; for(;p;p-=lowbit(p)) tr+=b[p]; return tr;} 
}T;
void solv()
{
    for(int i=1;i<=n;i++)
    {
        if(tp[a[i]].empty()) tv.push_back(a[i]);
        tp[a[i]].push_back(i);
    }
    sort(tv.rbegin(),tv.rend()),pos.insert(n+1); int l=0;
    for(int x:tv)
    {
        while(tv[l]>x+k)
        {
            for(int p:tp[tv[l]]) pos.erase(p);
            l++; 
        }
        sort(tp[x].rbegin(),tp[x].rend());
        for(int p:tp[x]) nxt[p]=(*pos.lower_bound(p)),pos.insert(p);
    }
    for(int x:tv) tp[x].clear();
    tv.clear(),pos.clear();
    for(int i=n;i>=1;i--)
    {
        f[i]=1;
        if(nxt[i]!=n+1) f[i]=f[nxt[i]]+T.query(a[nxt[i]]-1)-T.query(a[i]-1)+1;
        T.add(a[i],1),res+=f[i];
    }
    for(int i=1;i<=n;i++) T.add(a[i],-1);
}
void sol()
{
    cin>>n>>k,res=0;
    for(int i=1;i<=n;i++) cin>>a[i],t[a[i]]++,res-=t[a[i]];
    solv();
    for(int i=1;i<=n;i++) t[a[i]]--,a[i]=P-a[i]+1;
    solv(),cout<<res<<'\n';
}
signed main()
{
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    int TT; cin>>TT;
    while(TT--) sol();
    return 0;
}