平衡树(splay)+树套树 学习笔记
1、rotate函数
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=(tr[y].s[1]==x);
tr[z].s[(tr[z].s[1]==y)]=x,tr[x].p=z; //1
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y; //2
tr[x].s[k^1]=y,tr[y].p=x; //3
pushup(y),pushup(x);
}
2、splay函数
void splay(int x,int k)
{
while (tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if (z!=k)
{
if ((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x); //2
else rotate(y); //1
}
rotate(x);
}
if (!k) root=x;
}
树套树模板:
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=2000010,INF=1e8+10;
int n,m;
int w[N];
struct node
{
int v,p,s[2];
int sz;
void init(int tv,int tp)
{
v=tv,p=tp;
sz=1;
}
} tr[N];
int idx;
void pushup(int u)
{
tr[u].sz=tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=(tr[y].s[1]==x);
tr[z].s[(tr[z].s[1]==y)]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int &root,int x,int k)
{
while (tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if (z!=k)
{
if ((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if (!k) root=x;
}
void Insert(int &root,int v)
{
int u=root,p=0;
while (u)
{
p=u;
u=tr[u].s[v>tr[u].v];
}
u=++idx;
if (p) tr[p].s[v>tr[p].v]=u;
tr[u].init(v,p);
splay(root,u,0);
}
int get_rank(int root,int v)
{
int u=root;
int rk=0;
while (u)
{
if (tr[u].v<v)
{
rk+=tr[tr[u].s[0]].sz+1;
u=tr[u].s[1];
}
else u=tr[u].s[0];
}
return rk;
}
void Delete(int &root,int v)
{
int u=root;
while (u)
{
if (tr[u].v==v) break;
if (tr[u].v<v) u=tr[u].s[1];
else u=tr[u].s[0];
}
splay(root,u,0);
int l=tr[u].s[0],r=tr[u].s[1];
while (tr[l].s[1]) l=tr[l].s[1];
while (tr[r].s[0]) r=tr[r].s[0];
splay(root,l,0),splay(root,r,l);
tr[r].s[0]=0;
pushup(r),pushup(l);
}
void change(int &root,int v1,int v2)
{
Delete(root,v1);
Insert(root,v2);
}
int get_pre(int root,int v)
{
int u=root,res=-INF;
while (u)
{
if (tr[u].v<v)
{
res=max(res,tr[u].v);
u=tr[u].s[1];
}
else u=tr[u].s[0];
}
return res;
}
int get_next(int root,int v)
{
int u=root,res=INF;
while (u)
{
if (tr[u].v>v)
{
res=min(res,tr[u].v);
u=tr[u].s[0];
}
else u=tr[u].s[1];
}
return res;
}
struct Node
{
int l,r;
int rt;
} Tr[N];
void build(int u,int l,int r)
{
Tr[u].l=l,Tr[u].r=r;
Insert(Tr[u].rt,INF),Insert(Tr[u].rt,-INF);
for (int i=l;i<=r;i++) Insert(Tr[u].rt,w[i]);
if (l==r) return;
int mid=l+(r-l)/2;
build(u*2,l,mid);
build(u*2+1,mid+1,r);
}
void modify(int u,int x,int v)
{
change(Tr[u].rt,w[x],v);
if (Tr[u].l==Tr[u].r) return;
int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
if (x<=mid) modify(u*2,x,v);
else modify(u*2+1,x,v);
}
int query_cnt(int u,int l,int r,int x)
{
if (l<=Tr[u].l && Tr[u].r<=r) return get_rank(Tr[u].rt,x)-1;
int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
int res=0;
if (l<=mid) res+=query_cnt(u*2,l,r,x);
if (r>mid) res+=query_cnt(u*2+1,l,r,x);
return res;
}
int query_pre(int u,int l,int r,int x)
{
if (l<=Tr[u].l && Tr[u].r<=r) return get_pre(Tr[u].rt,x);
int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
int res=-INF;
if (l<=mid) res=max(res,query_pre(u*2,l,r,x));
if (r>mid) res=max(res,query_pre(u*2+1,l,r,x));
return res;
}
int query_next(int u,int l,int r,int x)
{
if (l<=Tr[u].l && Tr[u].r<=r) return get_next(Tr[u].rt,x);
int mid=Tr[u].l+(Tr[u].r-Tr[u].l)/2;
int res=INF;
if (l<=mid) res=min(res,query_next(u*2,l,r,x));
if (r>mid) res=min(res,query_next(u*2+1,l,r,x));
return res;
}
int main()
{
cin >> n >> m;
for (int i=1;i<=n;i++) cin >> w[i];
build(1,1,n);
while (m--)
{
int op;
cin >> op;
if (op==1)
{
int L,R,x;
cin >> L >> R >> x;
int res=query_cnt(1,L,R,x)+1;
cout << res << "\n";
}
else if (op==2)
{
int L,R,k;
cin >> L >> R >> k;
int l=-1,r=INF;
int res;
while (l<=r)
{
int mid=l+(r-l)/2;
if (query_cnt(1,L,R,mid)+1<=k)
{
res=mid;
l=mid+1;
}
else r=mid-1;
}
cout << res << "\n";
}
else if (op==3)
{
int pos,x;
cin >> pos >> x;
modify(1,pos,x);
w[pos]=x;
}
else if (op==4)
{
int L,R,x;
cin >> L >> R >> x;
int res=query_pre(1,L,R,x);
if (res==-INF) puts("-2147483647");
else cout << res << "\n";
}
else
{
int L,R,x;
cin >> L >> R >> x;
int res=query_next(1,L,R,x);
if (res==INF) puts("2147483647");
else cout << res << "\n";
}
}
return 0;
}