【乱来】 树套树

· · 算法·理论

基本

顾名思义,就是通过不同数据结构的套嵌产生一种新的数据结构,使得这个数据结构拥有更多的作用。

树套树的常数与空间巨大,谨慎使用。

线段树套平衡树

比较常见,一般用于处理区间第 k 排名,区间前驱、后继,区间某个数排名等问题。支持在线处理。

具体地,就是在每个线段树节点上开一棵平衡树,那么每一层的空间为 \Omicron(n)

由于套了两层,线段树套平衡树的空间与时间复杂度都是 \Omicron(n\log^2 n) 的。

特殊地,在求区间第 k 小时,其复杂度是 \Omicron(n\log^3 n) 的,虽然严格跑不满,但还是要注意。

例题

P3380 【模板】树套树

要求能够求区间某个数的排名是多少,区间第 k 小,在线修改,区间某个数的前驱 / 后继。

我们对每个线段树结点的对应区间开完平衡树后逐个解决:

然后就很板地写完了。这个东西最好卡常(不过 4s 懒得卡了)。

#include<bits/stdc++.h>
//#define ll long long
#define L (p<<1)
#define R (p<<1|1)
using namespace std;
template<typename T> inline void read(T &x){
    T w=1;
    x=0;
    char c=getchar();
    while(!isdigit(c)){
        if(c=='-') w=-1;
        c=getchar();
    }
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar();
    x*=w;
}
template<typename T> inline void write(T x){
    if(x<0) putchar('-'),x=(~x)+1;
    if(x>9) write(x/10);
    putchar(x%10^48);
}
const int N=5e4+10,inf=INT_MAX;
struct Splay{
    #define ch(p,q) (tr[p].ch[q])
    #define fa(p) (tr[p].fa)
    #define val(p) (tr[p].val)
    #define same(p) (tr[p].same)
    #define siz(p) (tr[p].siz)
    struct tree{
        int ch[2],fa;
        int same,siz,val;
    }tr[N*50];
    int id,rt[N<<2];
    int newnode(int val){
        tr[++id]={{0,0},0,1,1,val};
        return id;
    }
    void init(int p){
        int a=newnode(inf),b=newnode(-inf);
        tr[a].siz++,ch(a,0)=b,fa(b)=a;
        rt[p]=a;
        return ;
    }
    int ges(int x){
        return ch(fa(x),1)==x;
    }
    void update(int x){
        siz(x)=same(x);
        if(ch(x,0)) siz(x)+=siz(ch(x,0));
        if(ch(x,1)) siz(x)+=siz(ch(x,1));
    }
    void rotate(int x){
        int y=fa(x),z=fa(y);
        if(z) ch(z,ges(y))=x;
        int k=ges(x);
        ch(y,k)=ch(x,k^1),fa(ch(x,k^1))=y;
        ch(x,k^1)=y,fa(y)=x;
        fa(x)=z;
        update(y),update(x);
        return ;
    }
    void splay(int p,int x,int T){
        while(fa(x)!=T){
            int y=fa(x),z=fa(y);
    //      cout<<y<<" "<<z<<'\n';
            if(z!=T) rotate(ges(x)==ges(y) ? y : x);
            rotate(x);
    //      Sleep(1000);
        }
        if(!T) rt[p]=x; 
    }
    int pred(int now,int sum,int op){
        if(!now) return op;
        if(val(now)<sum) return pred(ch(now,1),sum,now);
        else return pred(ch(now,0),sum,op);
    }
    int succ(int now,int sum,int op){
        if(!now) return op;
        if(val(now)>sum) return succ(ch(now,0),sum,now);
        else return succ(ch(now,1),sum,op);
    }
    void ins(int p,int sum){
        int x=pred(rt[p],sum,0);
        splay(p,x,0);
    //  cout<<val(x)<<"\n";
    //  cout<<rt<<" "<<x<<"\n";
        int y=succ(rt[p],sum,0);
        splay(p,y,x);
        if(!ch(y,0)){
            int k=newnode(sum);
            ch(y,0)=k;
            fa(k)=y;
        }
        else same(ch(y,0))++,update(ch(y,0));
        update(y),update(x);
    }
    void delet(int p,int sum){
        int x=pred(rt[p],sum,0);
        splay(p,x,0);
        int y=succ(rt[p],sum,0);
        splay(p,y,x);
        same(ch(y,0))--;
        if(ch(y,0)&&val(ch(y,0))==sum){
            if(!same(ch(y,0))) ch(y,0)=0;
            else update(ch(y,0));
        }
        update(y),update(x);
    }
    int ran(int p,int sum){
        int x=pred(rt[p],sum,0);
        splay(p,x,0);
        int y=succ(rt[p],sum,0);
        splay(p,y,x);
        return siz(ch(x,0))+same(x)+1;
    }
}Tr;
int n,m,a[N];
void build(int l,int r,int p){
    Tr.init(p);
    for(int i=l;i<=r;i++) Tr.ins(p,a[i]);
    if(l==r) return ;
    int mid=l+r>>1;
    build(l,mid,L),build(mid+1,r,R);
    return ; 
}
int que_ran(int l,int r,int x,int y,int z,int p){
    if(x<=l&&r<=y){
        int res=Tr.ran(p,z)-2;
        return res;
    }
    int mid=l+r>>1,res=0;
    if(mid>=x) res+=que_ran(l,mid,x,y,z,L);
    if(mid<y) res+=que_ran(mid+1,r,x,y,z,R);
    return res;
}
void modify(int l,int r,int x,int z,int p){
    Tr.delet(p,a[x]),Tr.ins(p,z);
    if(l==r) return ;
    int mid=l+r>>1;
    if(mid>=x) modify(l,mid,x,z,L);
    else modify(mid+1,r,x,z,R);
    return ; 
}
int pre(int l,int r,int x,int y,int z,int p){
    if(x<=l&&r<=y) return Tr.tr[Tr.pred(Tr.rt[p],z,0)].val;
    int mid=l+r>>1,res=-inf;
    if(mid>=x) res=max(res,pre(l,mid,x,y,z,L));
    if(mid<y) res=max(res,pre(mid+1,r,x,y,z,R));
    return res;
}
int suc(int l,int r,int x,int y,int z,int p){
    if(x<=l&&r<=y) return Tr.tr[Tr.succ(Tr.rt[p],z,0)].val;
    int mid=l+r>>1,res=inf;
    if(mid>=x) res=min(res,suc(l,mid,x,y,z,L));
    if(mid<y) res=min(res,suc(mid+1,r,x,y,z,R));
    return res;
}
int main(){
//  freopen(".in","r",stdin);
//  freopen(".out","w",stdout);
    read(n),read(m);
    for(int i=1;i<=n;i++) read(a[i]);
    build(1,n,1);
    while(m--){
        int op,l,r,k;
        read(op);
        if(op==1){
            read(l),read(r),read(k);
            write(que_ran(1,n,l,r,k,1)+1);
            putchar('\n');
        }
        if(op==2){
            read(l),read(r),read(k);
            int LL=-1,RR=1e8+1;
            while(LL<RR){
                int mid=LL+RR>>1;
                if(que_ran(1,n,l,r,mid,1)+1>=k+1) RR=mid;
                else LL=mid+1; 
            }
            write(RR-1);
            putchar('\n');
        }
        if(op==3){
            read(l),read(k);
            modify(1,n,l,k,1);
            a[l]=k;
        }
        if(op==4){
            read(l),read(r),read(k);
            write(pre(1,n,l,r,k,1));
            putchar('\n');
        }
        if(op==5){
            read(l),read(r),read(k);
            write(suc(1,n,l,r,k,1));
            putchar('\n');
        }
    }
    return 0;
}
/*
9 3
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3

9 3
4 2 2 1 9 4 0 1 1
1 1 1 1
3 1 10
2 1 1 1

9 2
4 2 2 1 9 4 0 1 1
3 4 10
5 2 8 5
*/