平衡树(splay)+树套树 学习笔记

· · 算法·理论

1、rotate函数

void rotate(int x)
{
    int y=tr[x].p,z=tr[y].p;
    int k=(tr[y].s[1]==x);
    tr[z].s[(tr[z].s[1]==y)]=x,tr[x].p=z;  //1
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y; //2
    tr[x].s[k^1]=y,tr[y].p=x; //3
    pushup(y),pushup(x);
}

2、splay函数

void splay(int x,int k)
{
    while (tr[x].p!=k)
    {
        int y=tr[x].p,z=tr[y].p;
        if (z!=k)
        {
            if ((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x); //2
            else rotate(y); //1
        }
        rotate(x);
    }
    if (!k) root=x;
}

树套树模板:

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N=2000010,INF=1e8+10;
int n,m;
int w[N];
struct node
{
    int v,p,s[2];
    int sz;

    void init(int tv,int tp)
    {
        v=tv,p=tp;
        sz=1;
    }
} tr[N];
int idx;

void pushup(int u)
{
    tr[u].sz=tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}

void rotate(int x)
{
    int y=tr[x].p,z=tr[y].p;
    int k=(tr[y].s[1]==x);
    tr[z].s[(tr[z].s[1]==y)]=x,tr[x].p=z;
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
    tr[x].s[k^1]=y,tr[y].p=x;
    pushup(y),pushup(x);
}

void splay(int &root,int x,int k)
{
    while (tr[x].p!=k)
    {
        int y=tr[x].p,z=tr[y].p;
        if (z!=k)
        {
            if ((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if (!k) root=x;
}

void Insert(int &root,int v)
{
    int u=root,p=0;
    while (u) 
    {
        p=u;
        u=tr[u].s[v>tr[u].v];
    }
    u=++idx;
    if (p) tr[p].s[v>tr[p].v]=u;
    tr[u].init(v,p);
    splay(root,u,0);
}

int get_rank(int root,int v)
{
    int u=root;
    int rk=0;
    while (u)
    {
        if (tr[u].v<v) 
        {
            rk+=tr[tr[u].s[0]].sz+1;
            u=tr[u].s[1];
        }
        else u=tr[u].s[0];
    }
    return rk;
}

void Delete(int &root,int v)
{
    int u=root;
    while (u)
    {
        if (tr[u].v==v) break;
        if (tr[u].v<v) u=tr[u].s[1];
        else u=tr[u].s[0];
    }
    splay(root,u,0);
    int l=tr[u].s[0],r=tr[u].s[1];
    while (tr[l].s[1]) l=tr[l].s[1];
    while (tr[r].s[0]) r=tr[r].s[0];
    splay(root,l,0),splay(root,r,l);
    tr[r].s[0]=0;
    pushup(r),pushup(l);
}

void change(int &root,int v1,int v2)
{
    Delete(root,v1);
    Insert(root,v2);
}

int get_pre(int root,int v)
{
    int u=root,res=-INF;
    while (u)
    {
        if (tr[u].v<v) 
        {
            res=max(res,tr[u].v);
            u=tr[u].s[1];
        }
        else u=tr[u].s[0];
    }
    return res;
}

int get_next(int root,int v)
{
    int u=root,res=INF;
    while (u)
    {
        if (tr[u].v>v)
        {
            res=min(res,tr[u].v);
            u=tr[u].s[0];
        }
        else u=tr[u].s[1];
    }
    return res;
}

struct Node
{
    int l,r;
    int rt;
} Tr[N];

void build(int u,int l,int r)
{
    Tr[u].l=l,Tr[u].r=r;
    Insert(Tr[u].rt,INF),Insert(Tr[u].rt,-INF);
    for (int i=l;i<=r;i++) Insert(Tr[u].rt,w[i]);
    if (l==r) return;
    int mid=l+(r-l)/2;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
}

void modify(int u,int x,int v)
{
    change(Tr[u].rt,w[x],v);
    if (Tr[u].l==Tr[u].r) return;
    int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
    if (x<=mid) modify(u*2,x,v);
    else modify(u*2+1,x,v);
}

int query_cnt(int u,int l,int r,int x)
{
    if (l<=Tr[u].l && Tr[u].r<=r) return get_rank(Tr[u].rt,x)-1;
    int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
    int res=0;
    if (l<=mid) res+=query_cnt(u*2,l,r,x);
    if (r>mid) res+=query_cnt(u*2+1,l,r,x);
    return res;
}

int query_pre(int u,int l,int r,int x)
{
    if (l<=Tr[u].l && Tr[u].r<=r) return get_pre(Tr[u].rt,x);
    int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
    int res=-INF;
    if (l<=mid) res=max(res,query_pre(u*2,l,r,x));
    if (r>mid) res=max(res,query_pre(u*2+1,l,r,x));
    return res;
}

int query_next(int u,int l,int r,int x)
{
    if (l<=Tr[u].l && Tr[u].r<=r) return get_next(Tr[u].rt,x);
    int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
    int res=INF;
    if (l<=mid) res=min(res,query_next(u*2,l,r,x));
    if (r>mid) res=min(res,query_next(u*2+1,l,r,x));
    return res;
}

int main()
{   
    cin >> n >> m;
    for (int i=1;i<=n;i++) cin >> w[i];

    build(1,1,n);

    while (m--)
    {
        int op;
        cin >> op;
        if (op==1)
        {
            int L,R,x;
            cin >> L >> R >> x;
            int res=query_cnt(1,L,R,x)+1;
            cout << res << "\n";
        }
        else if (op==2)
        {
            int L,R,k;
            cin >> L >> R >> k;

            int l=-1,r=INF;
            int res;
            while (l<=r)
            {
                int mid=l+(r-l)/2;
                if (query_cnt(1,L,R,mid)+1<=k)
                {
                    res=mid;
                    l=mid+1;
                }
                else r=mid-1;
            }
            cout << res << "\n";
        }
        else if (op==3)
        {
            int pos,x;
            cin >> pos >> x;
            modify(1,pos,x);
            w[pos]=x;
        }
        else if (op==4)
        {
            int L,R,x;
            cin >> L >> R >> x;
            int res=query_pre(1,L,R,x);

            if (res==-INF) puts("-2147483647");
            else cout << res << "\n";
        }
        else
        {
            int L,R,x;
            cin >> L >> R >> x;
            int res=query_next(1,L,R,x);

            if (res==INF) puts("2147483647");
            else cout << res << "\n";
        }
    }

    return 0;
}