【模板】Splay

· · 个人记录

Splay 是平衡树的一种,拥有强大的功能。

  1. 功能简介

    主要表现:维护数列

    • 插入元素

    • 删除元素

    • 区间修改(例如区间加,区间赋值等)

    • 区间询问(例如区间最大值,区间和)

    • 区间翻转

  2. 我的理解

    • 是一颗相对平衡的二叉搜索树

    • 维护相对平衡的关键是双旋 ( splay )

    • 可以像线段树一样灵活地维护(下放)标记

    • 而平衡树优秀的时空复杂度和灵活程度,决定了其强大的功能

  3. 被坑的地方

    • 定义了全局变量 x ,又定义的一个局部变量 x 。
    • 没有提前插入两个虚拟结点 。
    • 注意 rotate 函数的细节问题。
  4. 复杂度

    O(n logn)

下面贴 splay 的模板题:NOI2005维护数列

pre:找到区间[l,r]的方法:splay(x,root),splay(y,x).
    其中, x=find(l-1), y=find(r+1), find(a)表示 a 在 splay 中的编号。
    这样 y 的左子树就对应区间[l,r].

1.对于插入操作, 我们可以利用 build() 函数, 类似线段树的建树, 再进行插入.

2.对于删除操作, 找到对应的区间, 再进行删除.

3.对于修改操作, 找到对应的区间, 打上修改标记.

4.对于翻转操作, 找到对应的区间, 打上翻转标记.

5.对于求和操作, 找到对应的区间, 返回改结点的子树和.

6.对于最大子段和, 直接 O(1) 输出 root 的最大子段和权值.  

7.下传标记优先级: 赋值 > 翻转及其他. 

8. updata() 同线段树维护最大子段和相同. 

code:
#include<iostream>
#include<cstdio>
using namespace std;
const int N=1002018,INF=1e9+7;
int n, q, top, cnt, rt; bool rev[N], tag[N]; char s[12];
int stk[N], fa[N], c[N][2], v[N], a[N], mx[N], lmx[N], rmx[N], sum[N], siz[N], id[N];
void updata(int x){
    int l=c[x][0],r=c[x][1];
    siz[x]=siz[l]+siz[r]+1;
    sum[x]=sum[l]+sum[r]+v[x];
    mx[x]=max(rmx[l]+v[x]+lmx[r], max(mx[l], mx[r]));
    lmx[x]=max(lmx[l], sum[l]+v[x]+lmx[r]);
    rmx[x]=max(rmx[r], sum[r]+v[x]+rmx[l]);
}
void down(int x){
    int l=c[x][0], r=c[x][1];
    if(tag[x]){
        rev[x]=tag[x]=0;
        if(l)tag[l]=1, sum[l]=siz[l]*(v[l]=v[x]);
        if(r)tag[r]=1, sum[r]=siz[r]*(v[r]=v[x]);
        if(v[x] >= 0){
            if(l)mx[l]=lmx[l]=rmx[l]=sum[l];
            if(r)mx[r]=lmx[r]=rmx[r]=sum[r];
        }
        else{
            if(l)mx[l]=v[x], lmx[l]=rmx[l]=0;
            if(r)mx[r]=v[x], lmx[r]=rmx[r]=0;
        }
    }
    if(rev[x]){
        rev[x]=0;
        rev[l]^=1; rev[r]^=1;
        swap(lmx[l],rmx[l]); swap(lmx[r],rmx[r]);
        swap(c[l][0], c[l][1]); swap(c[r][0], c[r][1]); 
    }
}
int find(int x, int kth){
    down(x);//cout<<"!!!!!"<<endl;
    int l=c[x][0], r=c[x][1];
    if(siz[l]+1==kth)return x;
    if(siz[l]>=kth)return find(l, kth);
    else return find(r, kth-siz[l]-1);
}
inline bool get(int x){
    return c[fa[x]][1]==x;
}
void rotate(int x, int &k){
    int y=fa[x], z=fa[y], w=get(x);
    if(y==k)k=x;else c[z][c[z][1]==y]=x;
    c[y][w]=c[x][w^1]; fa[c[y][w]]=y;
    c[x][w^1]=y; fa[y]=x; fa[x]=z;
    updata(y); updata(x);
}
void splay(int x, int &k){
    while(x != k){
        int y=fa[x], z=fa[y];
        if(y != k)
            rotate((get(x)^get(y))?x:y, k);
        rotate(x, k);
    }
}
void build(int l, int r, int f){
    int mid=(l+r)>>1, now=id[mid], pre=id[f];
    if(l==r){
        lmx[now]=rmx[now]=max(0, a[l]);
        mx[now]=sum[now]=a[l];
        tag[now]=rev[now]=0;
        siz[now]=1;
    }
    if(l<mid)build(l, mid-1, mid);
    if(mid<r)build(mid+1, r, mid);
    v[now]=a[mid]; fa[now]=pre;
    updata(now);
    c[pre][mid>=f]=now;
}
int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void INSERT(int k, int tot){
    for(int i=1;i<=tot;i++)a[i]=read();
    for(int i=1;i<=tot;i++){
        if(top>0)id[i]=stk[top--];
        else id[i]=++cnt;
    }
    build(1,tot,0);
    int z=id[(1+tot)>>1], x=find(rt, k+1), y=find(rt, k+2);
    splay(x, rt); splay(y, c[x][1]);
    fa[z]=y; c[y][0]=z;
    updata(y); updata(x);
}
int split(int k, int tot){
    int x=find(rt, k), y=find(rt, k+tot+1);
    splay(x, rt); splay(y, c[x][1]);
    return c[y][0];
}
void recycle(int x){
    int &l=c[x][0], &r=c[x][1];
    if(l)recycle(l);
    if(r)recycle(r);
    stk[++top]=x;
    fa[x]=l=r=tag[x]=rev[x]=0;
}
void DELETE(int k, int tot){
    int x=split(k, tot), y=fa[x];
    recycle(x); c[y][0]=0;
    updata(y); updata(fa[y]);
}
void MODIFY(int k, int tot, int val){
    int x=split(k, tot), y=fa[x];
    tag[x] = 1;v[x]=val;
    sum[x]=siz[x]*val;
    if(val >= 0)lmx[x]=rmx[x]=mx[x]=sum[x];
    else lmx[x]=rmx[x]=0, mx[x]=val;
    updata(y); updata(fa[y]);
}
void RAVER(int k, int tot){
    int x=split(k, tot), y=fa[x];
    if(!tag[x]){
        rev[x]^=1;
        swap(c[x][0], c[x][1]);
        swap(lmx[x], rmx[x]);
        updata(y); updata(fa[y]);
    }
}
int QUERY(int k,int tot){
    int x=split(k, tot);
    return sum[x];
}
void dfs(int x){
    down(x);
    if(c[x][0])dfs(c[x][0]);
    if(c[x][1])dfs(c[x][1]);
}
int main()
{
    n=read();q=read();
    for(int i=1;i<=n;i++)a[i+1]=read();
    mx[0]=a[1]=a[n+2]=-INF;
    for(int i=1;i<=n+2;i++)id[i]=i;
    build(1,n+2,0);
    rt=(n+3)>>1;cnt=n+2;
    while(q--){
        scanf("%s", s);
        if(s[0]=='M' && s[2]=='X'){printf("%d\n", mx[rt]);continue;}
        int x=read(),tot=read();
        if(s[0] == 'I')INSERT(x, tot);
        if(s[0] == 'D')DELETE(x, tot);
        if(s[0] == 'M'){
            int val=read();
            MODIFY(x, tot, val);
        }
        if(s[0] == 'R')RAVER(x, tot);
        if(s[0] == 'G')printf("%d\n", QUERY(x, tot));
    }
    return 0;
}

ps:留坑慢慢补