平衡树

· · 个人记录

diu个splayvector模板...emmm

luoguP3369普通平衡树

splay模板

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define re register
#define lson t[u].ch[0]
#define rson t[u].ch[1]
#define inf 0x7f7f7f7f
using namespace std;

int N;
int tot;
int root;
struct node
{
    int fa;
    int val;
    int ch[2];
    int siz,cnt;
} t[1000010];
int read(){int x;scanf("%d",&x);return x;}
void updata(int u)
{
    t[u].cnt=t[lson].cnt+t[rson].cnt+t[u].siz;//分清楚cnt和siz...
    return ;
}
void zig(int x)
{
    int y=t[x].fa;
    int tmp=t[x].ch[0];
    t[y].fa=x;t[x].ch[0]=y;
    t[tmp].fa=y;t[y].ch[1]=tmp;
    return ; 
}
void zag(int x)
{
    int y=t[x].fa;
    int tmp=t[x].ch[1];
    t[y].fa=x;t[x].ch[1]=y;
    t[tmp].fa=y;t[y].ch[0]=tmp;//当发现和上面的不对称的时候 你心里还没点13数吗...
    return ;
}
void rotate(int x)
{
    int y=t[x].fa;
    int z=t[y].fa;
    int flag=(t[y].val>t[z].val);
    if(t[x].val>t[y].val) zig(x);
    else zag(x);
    t[x].fa=z;t[z].ch[flag]=x;//看清楚x和z...
    updata(y);updata(x);
    return ;
}
void splay(int x,int goal)
{
    int y,z;
    while(t[x].fa!=goal)
    {
        y=t[x].fa;
        z=t[y].fa;
        if(z!=goal)
        {
            if((y==t[z].ch[0])^(x==t[y].ch[0]))//不是一条链
                rotate(x);
            else rotate(y); 
        }
        rotate(x);  
    }
    if(!goal) root=x;
    return ;
}
void insert(int x)
{
    int u=root,f=0;
    while(u&&t[u].val!=x){f=u;u=t[u].ch[t[u].val<x];}
    if(u) t[u].siz++,t[u].cnt++;
    else
    {
        u=++tot;t[u].val=x;
        t[u].siz=t[u].cnt=1;
        t[u].fa=f;if(f) t[f].ch[t[f].val<x]=u;
    }
    splay(u,0);return ;
}
void find(int x)
{
    int u=root;
    while(t[u].ch[t[u].val<x]&&t[u].val!=x) u=t[u].ch[t[u].val<x];
    splay(u,0);return ;
}
int nxt(int x,int flag)
{
    find(x);
    int u=root;
    if((t[u].val<x&&!flag)||(t[u].val>x&&flag)) return u;
    u=t[u].ch[flag];//不是while!!!
    while(t[u].ch[flag^1]) u=t[u].ch[flag^1];
    return u;
}
void del(int x)
{
    int t1=nxt(x,0);//找到前驱的位置
    int t2=nxt(x,1);//找到后继的位置
    splay(t1,0);splay(t2,t1);
    if(t[t[t2].ch[0]].siz>1) t[t[t2].ch[0]].siz--,t[t[t2].ch[0]].cnt--;
    else t[t2].ch[0]=0;
    return ;
}
int search(int x)
{
    int u=root;
    while(1)
    {
        if(t[lson].cnt>=x) u=lson;
        else if(t[lson].cnt+t[u].siz<x)
            {x-=t[lson].cnt+t[u].siz;u=rson;}
        else {splay(u,0);return u;}
    }
}
int main()
{
    N=read();
    re int i,opt,x;
    insert(-inf);insert(inf);
    for(i=1;i<=N;i++)
    {
        opt=read();x=read();
        if(opt==1) insert(x);
        else if(opt==2) del(x);
        else if(opt==3){find(x);printf("%d\n",t[t[root].ch[0]].cnt);}
        else if(opt==4) printf("%d\n",t[search(x+1)].val);
        else if(opt==5) printf("%d\n",t[nxt(x,0)].val);
        else printf("%d\n",t[nxt(x,1)].val);
    }
    return 0;
}
//注意区分“寻找数x”和“寻找第x个数”

vector模板

#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define ll long long
#define re register
using namespace std;

int N; 
vector<int> t;
int read(){int x;scanf("%d",&x);return x;}
int main()
{
    N=read();
    re int i,opt,x;
    for(i=1;i<=N;i++)
    {
        opt=read();x=read();
        if(opt==1) t.insert(upper_bound(t.begin(),t.end(),x),x);
        else if(opt==2) t.erase(upper_bound(t.begin(),t.end(),x)-1);
        else if(opt==3) printf("%d\n",lower_bound(t.begin(),t.end(),x)-t.begin()+1);
        else if(opt==4) printf("%d\n",t.at(x-1));
        else if(opt==5) printf("%d\n",t.at(lower_bound(t.begin(),t.end(),x)-t.begin()-1));
        else printf("%d\n",t.at(upper_bound(t.begin(),t.end(),x)-t.begin()));
    }
    return 0;
}

luoguP2286宠物收养场

splay版

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const long long inf=9999999999999999;
const int mod=1000000;
int N;
int tot;
int root;
struct node
{
    int ch[2],fa;
    long long val; 
    int cnt,size;
} t[100010];
int flag;
long long Ans;
void zig(int x)
{
    int y=t[x].fa;
    int tmp=t[x].ch[0];
    t[x].ch[0]=y; t[y].fa=x;
    t[y].ch[1]=tmp; t[tmp].fa=y;
    return ;
}
void zag(int x)
{
    int y=t[x].fa;
    int tmp=t[x].ch[1];
    t[x].ch[1]=y; t[y].fa=x;
    t[y].ch[0]=tmp; t[tmp].fa=y;
    return ;
}
void updata(int id)
{
    t[id].size=t[t[id].ch[0]].size+t[t[id].ch[1]].size+t[id].cnt;
    return ;
}
void rotate(int x)
{
    int y=t[x].fa;
    int z=t[y].fa;
    int tmp=t[y].val>t[z].val;
    if(t[x].val>t[y].val) zig(x);
    else zag(x);
    t[x].fa=z; t[z].ch[tmp]=x;
    updata(y); updata(x);
    return ;
}
void splay(int x,int goal)
{
    while(t[x].fa!=goal)
    {
        int y=t[x].fa;
        int z=t[y].fa;
        if(z!=goal)
        {
            if((t[y].ch[0]==x)^(t[z].ch[1]==y)) rotate(y);//同为左儿子或右儿子 
            else rotate(x);//一左一右 
        }
        rotate(x);
    }
    if(!goal) root=x;
    return ;
}
void insert(long long x)
{
    int u=root,f=0;
    while(u && t[u].val!=x) 
    {
        f=u;
        u=t[u].ch[t[u].val<x];
    }
    if(u) t[u].cnt++,t[u].size++;
    else
    {
        u=++tot;
        if(f) t[f].ch[t[f].val<x]=u;
        t[u].fa=f; t[u].val=x;
        t[u].size++; t[u].cnt++;
    }
    splay(u,0);
    return ;
}
void find(long long x)
{
    int u=root;
    if(!u) return ;
    while(t[u].val!=x && t[u].ch[t[u].val<x]) 
        u=t[u].ch[t[u].val<x];
    splay(u,0);
    return ;
}
int nxt(long long x,int f)
{
    find(x);
    if((t[root].val<x && !f) || (t[root].val>x && f)) 
    return root;
    int u=t[root].ch[f];
    while(t[u].ch[f^1]) u=t[u].ch[f^1];
    return u;
}
void del(long long x)
{
    int t1=nxt(x,0);
    int t2=nxt(x,1);
    splay(t1,0); splay(t2,t1);
    if(t[t[t2].ch[0]].cnt>1)
    {
        t[t[t2].ch[0]].cnt--;
        t[t[t2].ch[0]].size--;
        splay(t[t2].ch[0],0);
    }
    else t[t2].ch[0]=0;
    return ;
}
long long get(long long now)
{
    return now>0 ? now : -now;
}
int main()
{
    scanf("%d",&N);
    insert(-inf); insert(inf);
    for(int i=1;i<=N;i++)
    {
        int opt;
        long long num;
        scanf("%d%lld",&opt,&num);
        if(!opt) opt=-1;
        if(!flag || flag/opt>0)
            insert(num);
        else
        {
            int a1=t[nxt(num,0)].val;
            int a2=t[nxt(num,1)].val;
            if(get(a1-num)<=get(a2-num))
            {
                Ans+=get(a1-num);
                del(a1); 
            }
            else 
            {
                Ans+=get(a2-num);
                del(a2);
            }
        }
        Ans%=mod;
        flag+=opt;
    }
    printf("%lld",Ans);
    return 0;
}

vector版

#include<cstdio>
#include<vector> 
#include<cstring>
#include<algorithm>
#define ll long long
#define re register
using namespace std;
const int mod=1000000;
const int inf=0x7f7f7f7f;

int Ans;
int cnt,N;
vector<int> t;
int read(){int x;scanf("%d",&x);return x;}
int main()
{
    N=read();
    re int i,opt;
    re int x;
    re int t1,t2;
    re int x1,x2;
    t.insert(lower_bound(t.begin(),t.end(),-inf),-inf);
    t.insert(lower_bound(t.begin(),t.end(),inf),inf);
    for(i=1;i<=N;i++)
    {
        opt=read();x=read();
        if(opt)
        {
            cnt++;
            if(cnt>0) t.insert(lower_bound(t.begin(),t.end(),x),x);
            else
            {
                t1=lower_bound(t.begin(),t.end(),x)-t.begin();//next
                t2=lower_bound(t.begin(),t.end(),x)-t.begin()-1;//last
                x1=t[t1];x2=t[t2];
                if(x2==-inf||x1-x<x-x2) Ans+=x1-x,t.erase(lower_bound(t.begin(),t.end(),x1));
                //一定要特判-inf...因为x-(-inf)会爆int...然后就炸了qaq
                else Ans+=x-x2,t.erase(lower_bound(t.begin(),t.end(),x2));
                Ans%=mod;
            }
        }
        else
        {
            cnt--;
            if(cnt<0) t.insert(lower_bound(t.begin(),t.end(),x),x);
            else
            {
                t1=lower_bound(t.begin(),t.end(),x)-t.begin();
                t2=lower_bound(t.begin(),t.end(),x)-t.begin()-1;
                x1=t[t1];x2=t[t2];
                if(x2==-inf||x1-x<x-x2) Ans+=x1-x,t.erase(lower_bound(t.begin(),t.end(),x1));
                else Ans+=x-x2,t.erase(lower_bound(t.begin(),t.end(),x2));
                Ans%=mod;               
            }
        }
    }
    printf("%d",Ans);
    return 0;
}