树套树

· · 个人记录

树套树

在学习 二维线段树 时,介绍了线段树套线段树。

同时确保您已经学会 平衡树。

P3380 【模板】二逼平衡树(树套树)

题目让我们维护的 5 个操作中,但如果没有要求区间,那么就是平衡树模板题,但是要求了区间后,我们就要找区间之间的关系,可以用线段树维护区间之间直接的关系。和二维线段树相同的,先开一个线段树,然后每个节点维护一个平衡树。然后就是基础的模板操作了(只是难写了亿点点)。

代码

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
int r_r(){//快读 
    int k=0,f=1;
    char c=getchar();
    while(!isdigit(c)){
        if(c=='-')f=-1;
        c=getchar();
    }
    while(isdigit(c)){
        k=(k<<1)+(k<<3)+(c^48);
        c=getchar();
    }
    return k*f;
}
const int o_o=5e4+10;
const int m_a=2147483647;//一定设到极限大(0x3f3f3f3f 会被卡) 
struct sp{
    int s_z;//树的大小 
    int n;//值相同节点数量 
    int v_l;//节点价值 
    int s_n[2];//左右儿子 
    int f_a;//父节点 
}t_s[o_o*40];//平衡树 
struct tr{
    int l;//左儿子 
    int r;//右儿子 
    int g_g;//平衡树根节点 
}t_r[o_o<<2];//线段树 
int a_a[o_o];//原序列 
int n,m,x_p;
int l_l(int k){//左子树 
    return k<<1;
}
int r_r(int k){//右子树 
    return k<<1|1;
}
void u_p(int x){//更新子树大小 
    t_s[x].s_z=t_s[t_s[x].s_n[0]].s_z+t_s[t_s[x].s_n[1]].s_z+t_s[x].n;
} 
void t_n(int x){//旋转平衡树 
    int f=t_s[x].f_a;
    int f_f=t_s[f].f_a;
    int k_k=(t_s[f].s_n[1]==x);//判断节点在父节点的方位 

    //旋转操作 
    t_s[f_f].s_n[(t_s[f_f].s_n[1]==f)]=x;
    t_s[x].f_a=f_f;
    t_s[t_s[x].s_n[k_k^1]].f_a=f;
    t_s[f].s_n[k_k]=t_s[x].s_n[k_k^1];
    t_s[x].s_n[k_k^1]=f;
    t_s[f].f_a=x;

    //更新新子节点 
    u_p(f);
    //更新新父节点 
    u_p(x);
}
void s_p(int x,int g_g,int k){
    while(t_s[x].f_a!=g_g){//没有旋转到根 
        int f=t_s[x].f_a;//父节点 
        int f_f=t_s[f].f_a;//爷节点 
        if(f_f!=g_g)(t_s[f_f].s_n[1]==f)^(t_s[f].s_n[1]==x)?t_n(f):t_n(x);
        //同方位儿子转节点父亲,否则转接节点 

        t_n(x);//旋转节点 
    }
    if(g_g==0)t_r[k].g_g=x;//更新根节点 
}
int yv(int v,int f_a){//初始化节点 
    t_s[++x_p].v_l=v;//初值 
    t_s[x_p].f_a=f_a;//祖先 
    t_s[x_p].n=1;//初始化相同节点数量 
    u_p(x_p);//更新子树大小 
    return x_p;
}
void a_d(int v,int k){
    int n_n=t_r[k].g_g;//记录根节点 
    int f_a=0;//初始化父节点 
    if(!n_n){//初始化平衡树的根 
        n_n=yv(v,0);//初始化节点 
        t_r[k].g_g=n_n;//更新根 
        return ;
    }
    while(n_n&&(t_s[n_n].v_l!=v)){//找到与当前值相等的点 
        f_a=n_n;
        if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
        else n_n=t_s[n_n].s_n[0];
    }
    if(v==t_s[n_n].v_l&&n_n)t_s[n_n].n++;//找到节点相同的点 
    else{
        n_n=yv(v,f_a);//初始化新节点 
        if(f_a){//判断左右儿子 
            if(t_s[f_a].v_l<v)t_s[f_a].s_n[1]=n_n;
            else t_s[f_a].s_n[0]=n_n;
        }
    }
    s_p(n_n,0,k);//旋转平衡树
}
void f_i(int v,int k){
    int n_n=t_r[k].g_g;//根节点 
    if(!n_n)return ;

    //找排名 
    while(t_s[n_n].s_n[t_s[n_n].v_l<v]&&t_s[n_n].v_l!=v){//子树有节点,并且找前驱所以值不能相等 
        if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
        else n_n=t_s[n_n].s_n[0];
    }
    s_p(n_n,0,k);//旋转平衡树 
}
int n_t(int x,int b_b,int k){//b_b 0 后继,b_b 1 前驱 
    f_i(x,k);//找 x 的排名 
    int n_n=t_r[k].g_g;//记录根节点 
    if((b_b&&t_s[n_n].v_l<x)||(!b_b&&t_s[n_n].v_l>x))return n_n;//达到边界,找到目标 
    n_n=t_s[n_n].s_n[b_b^1];//跳过边界,往回找 
    while(t_s[n_n].s_n[b_b])n_n=t_s[n_n].s_n[b_b];//逼近目标 
    return n_n;
}
void d_l(int x,int k){
    int n_n=t_r[k].g_g;//记录根节点 
    int q_q=n_t(x,1,k);//前驱 
    int h_j=n_t(x,0,k);//后继 
    s_p(h_j,0,k);//将后继变为根节点 
    s_p(q_q,h_j,k);//将前驱变为根节点子节点 
    int k_k=t_s[q_q].s_n[1];//目标节点 
    if(t_s[k_k].n>1)--t_s[k_k].n,s_p(k_k,0,k);//有多个相同值,减去一个并将目标节点转到根节点 
    else t_s[q_q].s_n[1]=0;//清空节点 
    u_p(q_q);//更新节点信息 
}
void b_t(int k,int l,int r){
    //控制边界 
    a_d(m_a,k);//加最大节点 
    a_d(-m_a,k);//加最小节点 

    if(l==r)return ;//叶子节点 
    int m_i=(l+r)>>1;
    b_t(l_l(k),l,m_i);//左子树 
    b_t(r_r(k),m_i+1,r);//右子树 
}
void s_d(int k,int l,int r,int i,int v_l){
    int m_i=(l+r)>>1;
    a_d(v_l,k);//节点平衡树加点 
    if(l==r)return ;//到叶子节点返回 
    if(m_i>=i)s_d(l_l(k),l,m_i,i,v_l);//左子树 
    else s_d(r_r(k),m_i+1,r,i,v_l);//右子树 
} 
int s_p(int k,int l,int r,int i,int x,int y){
    if(l>y||r<x)return 0;//超过边界 
    if(l>=x&&r<=y){//在范围内 
        f_i(i,k);//找排名 
        int n_n=t_r[k].g_g;

        //根据子树大小输出排名 
        if(t_s[n_n].v_l>=i)return t_s[t_s[n_n].s_n[0]].s_z-1;
        else return t_s[t_s[n_n].s_n[0]].s_z+t_s[n_n].n-1;
    }
    int m_i=(l+r)>>1;
    return s_p(l_l(k),l,m_i,i,x,y)+s_p(r_r(k),m_i+1,r,i,x,y);//统计左右子树排名 
}
void s_g(int k,int l,int r,int k_l,int v_l){
    d_l(a_a[k_l],k);//删旧点 
    a_d(v_l,k);//补新点 
    if(l==r&&l==k_l){//达到目标节点 
        a_a[k_l]=v_l;//更新 
        return ;
    }
    int m_i=(l+r)>>1;
    if(m_i>=k_l)s_g(l_l(k),l,m_i,k_l,v_l);//左子树 
    else s_g(r_r(k),m_i+1,r,k_l,v_l);//右子树 
}
int s_q(int k,int l,int r,int x,int y,int i){
    if(l>y||r<x)return -m_a;//不在范围内 
    if(l>=x&&r<=y)return t_s[n_t(i,1,k)].v_l;//在范围内 
    int m_i=(l+r)>>1;
    return max(s_q(l_l(k),l,m_i,x,y,i),s_q(r_r(k),m_i+1,r,x,y,i));//返回最大值(越大越逼近) 
}
int s_h(int k,int l,int r,int x,int y,int i){
    if(l>y||r<x)return m_a;//不在范围内 
    if(l>=x&&r<=y)return t_s[n_t(i,0,k)].v_l;//在范围内 
    int m_i=(l+r)>>1;
    return min(s_h(l_l(k),l,m_i,x,y,i),s_h(r_r(k),m_i+1,r,x,y,i));//返回最小值(越小越逼近) 
}
int s_k(int x,int y,int i){
    int l=0,r=1e8,m_i,a_s;
    while(l<=r){//二分找值 
        m_i=(l+r)>>1;
        int b_b=s_p(1,1,n,m_i,x,y)+1;//查数的排名 
        if(b_b>i)r=m_i-1;//超过目标排名 
        else l=m_i+1,a_s=m_i;//记录目前情况 
    }
    return a_s;
}
int main(){
    n=r_r(),m=r_r();
    b_t(1,1,n);//建树 
    for(int i=1;i<=n;++i)a_a[i]=r_r(),s_d(1,1,n,i,a_a[i]);//加点 
    for(int i=1;i<=m;++i){
        int op=r_r(),l=r_r(),r=r_r(),k;
        if(op==1)k=r_r(),printf("%d\n",s_p(1,1,n,k,l,r)+1);//查排名 
        if(op==2)k=r_r(),printf("%d\n",s_k(l,r,k));//查值 
        if(op==3)s_g(1,1,n,l,r);//修改 
        if(op==4)k=r_r(),printf("%d\n",s_q(1,1,n,l,r,k));//前驱 
        if(op==5)k=r_r(),printf("%d\n",s_h(1,1,n,l,r,k));//后继 
    }
    return 0;
}

注意最导致初始化的时候要保证值足够大。

会发现不开 O_2 只能过 3,4 个点,开了 O_2 仍有一个点会被卡掉。那就只能玄学优化了。

我们将建树的过程更改:

void b_t(int k,int l,int r){
    //控制边界 
    a_d(m_a,k);//加最大节点 
    a_d(-m_a,k);//加最小节点 

    for(int i=l;i<=r;i++)a_d(a_a[i],k);//读入范围内节点 
    if(l==r)return ;//叶子节点 
    int m_i=(l+r)>>1;
    b_t(k<<1,l,m_i);//左子树 
    b_t(k<<1|1,m_i+1,r);//右子树 
}

直接将每个线段树节点的值范围办函的点全部读入。

再加两个宏:

#define l_s(x)t_s[x].s_n[0]
#define r_s(x)t_s[x].s_n[1]

来访问左右儿子。

注意多次调用函数,会使效率降低,所以删去

int l_l(int k){//左子树 
    return k<<1;
}
int r_r(int k){//右子树 
    return k<<1|1;
}

直接计算。

还有判断“优化”:

if(t_s[n_n].v_l<v)n_n=t_s[n_n].s_n[1];
else n_n=t_s[n_n].s_n[0];

可以写成:

n_n=t_s[n_n].s_n[t_s[n_n].v_l<v];

最后函数前加上 inline

记得开 O_2

AC 代码

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<vector>
#include<random>
#include<ctime>
using namespace std;
#define il inline
int r_r(){//快读 
    long long k=0,f=1;
    char c=getchar();
    while(!isdigit(c)){
        if(c=='-')f=-1;
        c=getchar();
    }
    while(isdigit(c)){
        k=(k<<1)+(k<<3)+(c^48);
        c=getchar();
    }
    return k*f;
}
const int N=5e4+5;
#define l_s(x)t_s[x].s_n[0]
#define r_s(x)t_s[x].s_n[1]
const int m_a=2147483647;
struct ts{
    int s_z,n,v_l,s_n[2],f_a;
}t_s[N*50];
struct t_r{
    int l,r,g_g;
}t_r[N*4];
int a_a[N],n,m,x_p;
il void u_p(int x){
    t_s[x].s_z=t_s[l_s(x)].s_z+t_s[r_s(x)].s_z+t_s[x].n;
} 
il void t_n(int x){
    int f=t_s[x].f_a;
    int f_f=t_s[f].f_a;
    int b_b=(t_s[f].s_n[1]==x);
    t_s[f_f].s_n[(t_s[f_f].s_n[1]==f)]=x;
    t_s[x].f_a=f_f;
    t_s[t_s[x].s_n[b_b^1]].f_a=f;
    t_s[f].s_n[b_b]=t_s[x].s_n[b_b^1];
    t_s[x].s_n[b_b^1]=f;
    t_s[f].f_a=x;
    u_p(f);
    u_p(x);
}
il void s_p(int x,int g_g,int k){
    while(t_s[x].f_a!=g_g){
        int f=t_s[x].f_a,f_f=t_s[f].f_a;
        if(f_f!=g_g)(t_s[f_f].s_n[1]==f)^(t_s[f].s_n[1]==x)?t_n(x):t_n(f);
        t_n(x);
    }
    if(g_g==0)t_r[k].g_g=x;
}
il int yv(int v,int f_a){
    t_s[++x_p].v_l=v;
    t_s[x_p].f_a=f_a;
    t_s[x_p].n=1;
    u_p(x_p);
    return x_p;
}
il void a_d(int x,int k){
    int n_n=t_r[k].g_g;
    int f_a=0;
    if(!n_n){
        n_n=yv(x,0);
        t_r[k].g_g=n_n;
        return ;
    }
    while(n_n&&(t_s[n_n].v_l != x))f_a=n_n,n_n=t_s[n_n].s_n[t_s[n_n].v_l<x];
    if(x==t_s[n_n].v_l&&n_n)t_s[n_n].n ++;
    else {
        n_n=yv(x,f_a);
        if(f_a)t_s[f_a].s_n[t_s[f_a].v_l<x]=n_n;
    }
    s_p(n_n,0,k );
}
il void f_i(int x,int k){
    int n_n=t_r[k].g_g;
    if(!n_n)return ;
    while(t_s[n_n].s_n[t_s[n_n].v_l<x]&&t_s[n_n].v_l!=x)
        n_n=t_s[n_n].s_n[t_s[n_n].v_l<x];
    s_p(n_n,0,k);
}
il int n_t(int x,int b_b,int k){
    f_i(x,k);
    int n_n=t_r[k].g_g;
    if((b_b&&t_s[n_n].v_l<x)||(!b_b&&t_s[n_n].v_l>x))return n_n;
    n_n=t_s[n_n].s_n[b_b^1];
    while(t_s[n_n].s_n[b_b])n_n=t_s[n_n].s_n[b_b];
    return n_n;
}
il void d_l(int x,int k){
    int n_n=t_r[k].g_g;
    int q_q=n_t(x,1,k);
    int h_j=n_t(x,0,k);
    s_p(h_j,0,k);
    s_p(q_q,h_j,k);
    int k_l=t_s[q_q].s_n[1];
    if(t_s[k_l].n>1){
        --t_s[k_l].n;
        s_p(k_l,0,k);
    }else t_s[q_q].s_n[1]=0;
    u_p(q_q);
}
void b_t(int k,int l,int r){
    a_d(m_a,k),a_d(-m_a,k);
    for(int i=l;i<=r;i++)a_d(a_a[i],k);
    if(l==r)return;
    int m_i=(l+r)>>1;
    b_t(k*2,l,m_i);
    b_t(k*2+1,m_i+1,r);
}
il int s_p(int k,int l,int r,int i,int x,int y){
    if(l>y||r<x)return 0;
    if(l>=x&&r<=y){
        f_i(i,k);
        int n_n=t_r[k].g_g;
        if(t_s[n_n].v_l>=i)return t_s[l_s(n_n)].s_z-1;
        else return t_s[l_s(n_n)].s_z+t_s[n_n].n-1;
    }
    int m_i=(l+r)>>1;
    return s_p(k*2,l,m_i,i,x,y)+s_p(k*2+1,m_i+1,r,i,x,y);
}
il void s_g(int k,int l,int r,int k_l,int v_l){
    d_l(a_a[k_l],k);
    a_d(v_l,k);
    if(l==r&&l==k_l){
        a_a[k_l]=v_l;
        return ;
    }
    int m_i=(l+r)>>1;
    if(m_i>=k_l)s_g(k*2,l,m_i,k_l,v_l);
    else s_g(k*2+1,m_i+1,r,k_l,v_l);
}
il int s_q(int k,int l,int r,int x,int y,int i){ 
    if(l>y||r<x)return -m_a;
    if(l>=x&&r<=y)return t_s[n_t(i,1,k)].v_l;
    int m_i=(l+r)>>1;
    return max(s_q(k*2,l,m_i,x,y,i),s_q(k*2+1,m_i+1,r,x,y,i));
}
il int s_h(int k,int l,int r,int x,int y,int i){
    if(l>y||r<x)return m_a;
    if(l>=x&&r<=y)return t_s[n_t(i,0,k)].v_l;
    int m_i=(l+r)>>1;
    return min(s_h(k*2,l,m_i,x,y,i),s_h(k*2+1,m_i+1,r,x,y,i));
}
il int s_k(int x,int y,int i){
    int l=0,r=1e8,m_i,b_b,a_s;
    while(l<=r){
        m_i=(l+r)>>1;
        b_b=s_p(1,1,n,m_i,x,y)+1;
        if(b_b>i)r=m_i-1;
        else l=m_i+1,a_s=m_i;
    }
    return a_s;
}
int main(){
    n=r_r(),m=r_r();
    int op,l,r,k;
    for(int i=1;i<=n;++i)a_a[i]=r_r();
    b_t(1,1,n);
    for(int i=1;i<=m;++i){
        op=r_r(),l=r_r(),r=r_r();
        if(op==1)k=r_r(),printf("%d\n",s_p(1,1,n,k,l,r)+1);
        if(op==2)k=r_r(),printf("%d\n",s_k(l,r,k));
        if(op==3)s_g(1,1,n,l,r);
        if(op==4)k=r_r(),printf("%d\n",s_q(1,1,n,l,r,k));
        if(op==5)k=r_r(),printf("%d\n",s_h(1,1,n,l,r,k));
    }
    return 0;
}

本题可以用分块的方法写,而且跑的非常快。虽然数据点卡常,但是为了练手,建议写一写,还可以树状数组套平衡树,这里不再赘述,感兴趣可以自己试一试。