P11390题解

· · 题解

线段树典题,一开始写假了写篇题解记录一下。

容易发现是扫描线统计全局左端点布尔值和。

直接暴力更新前 k 个前驱就行。

发现根本不需要容斥。

线段树的每一节点,记 res[s] 为在当前节点被 sk 覆盖的点集,注意这里是准确值,不统计 s 有包含关系的,对于 0 也要统计到。

ts 为当前节点的 tag 的状压表示,所以将所有的 s|ts 累计到统计数组里就行。标记永久化,每次打 tag 或者撤销 tag 的时候就能找到原 tag

最后扫一遍就行,最大点 717ms

#include<bits/stdc++.h>
using namespace std;
typedef long long ll1;
template<typename T>
void in(T &n){
    n=0;char c=getchar();bool flag=0;
    for(;c<'0'||c>'9';c=getchar()) if (c=='-') flag=1;
    for(;c>='0'&&c<='9';c=getchar()) (n*=10)+=(c^48);
    if (flag) n=-n;
}
const int N=1e5+5;
const int mxs=16;
int n,K;
int a[N],lst[N],mp[N];
/*
静态地求出全局的存在出现1-k次的数的子区间的个数
//首先可以典上扫描线
//每次扫到新的数的时候
//发现当前数的cnt++
//然后找到当前数的前4个位置(如不存在不访问)
//把[lst[i]+1),i]的k加上1,相对应地改变
//动态维护4个k值>0的点数即可
*/
struct Node{
    int l,r,len;
    int res[16];
    int cnt[4],s;
}t[N*8];
#define ls (o<<1)
#define rs (ls|1)
#define mid ((l+r)>>1)
void build(int o,int l,int r){
    t[o]=(Node){l,r,r-l+1};t[o].s=0;
    memset(t[o].cnt,0,sizeof(t[o].cnt));
    memset(t[o].res,0,sizeof(t[o].res));
    t[o].res[0]=r-l+1;
    if(l==r)return;
    build(ls,l,mid),build(rs,mid+1,r);
}
int get_s(int o){
    int res=0;
    for(int i=0;i<4;i++)if(t[o].cnt[i])res+=1<<i;
    return res;
}
void update(int o){
    for(int i=0;i<16;i++)t[o].res[i]=0;
    if(t[o].l==t[o].r)t[o].res[t[o].s]=1;
    else for(int i=0;i<16;i++)t[o].res[i|t[o].s]+=t[ls].res[i]+t[rs].res[i];
}
bool modify(int o,int lt,int rt,int op,int ad){
//由于更新的常数过大,所以我们选择有变化的时候再更新,卡一卡常数
    int l=t[o].l,r=t[o].r;
    if(l>=lt&&r<=rt){
        int ts=t[o].s;
        t[o].cnt[op]+=ad;
        t[o].s=get_s(o);
        if(ts==t[o].s)return 0;
        update(o);
        return 1;
    }
    bool fg=0;
    if(lt<=mid)fg|=modify(ls,lt,rt,op,ad);
    if(rt>mid)fg|=modify(rs,lt,rt,op,ad);
    if(fg)update(o);
    return fg;
}
void init(){
    in(n);in(K);
    for(int i=1;i<=n;i++){
        in(a[i]);
        lst[i]=mp[a[i]];
        mp[a[i]]=i;
    }
    build(1,1,n);
}
void work(){
    ll1 res=0;
    int cnt=0,x;
    for(int i=1;i<=n;i++){
        cnt=1;x=i;
        for(;cnt<=K&&x;x=lst[x],cnt++){
            modify(1,lst[x]+1,x,cnt-1,1);
        }
        cnt=1;x=lst[i];
        for(;cnt<=K&&x;x=lst[x],cnt++){
            modify(1,lst[x]+1,x,cnt-1,-1);
        }
        res+=t[1].res[(1<<K)-1];
    }
    printf("%lld",res);
}
signed main(){
    init(),work();
    return 0;
}