树套树

· · 个人记录

树套树

线段树套线段树

维护多维信息时,可以考虑使用树套树来实现

思路

考虑使用树套树实现二维平面上的单点修改、区间查询,则使用一维线段树作为最外层线段树,最底层的 1n 个节点的子树的就分别代表第 1n 行的线段树,而其父节点就代表两个子节点的子树所在的一片区域

空间复杂度

一般情况下,最外层线段树的每个节点都建立一个完整的子线段树是不太可能的,因为太耗空间,可以采用动态开点的方法,对于一次单点修改,会涉及到外层线段树的 \log n 个节点,而且每个节点的子树也对应的涉及到 \log n 个 节点,所以单次修改所需要的空间消耗是 \log^2 n

时间复杂度

对于每一次询问,我们在外层线段树上需要 \log n 的时间进行查询,而在外层节点所对应的内层线段树上也要进行 \log n 次操作,所以时间复杂度也是 \log n

例题

本题可以使用线段树套线段树的方法进行维护,首先按照第一维进行排序,再用树套树维护第二维和第三维即可

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#define ll int

inline ll read()
{
    ll x=0,f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x*f;
}

const ll maxn=2e5+10;
ll n,m;
ll ans[maxn];
struct data 
{
    ll a,b,c,id;
} s[maxn];

struct seg
{
    struct node
    {
        node *ls,*rs;
        ll v;

        node(ll v=0) : v(v)
        {
            ls=rs=NULL;
        }
    };

    node *rt;

    inline void upd(node *&now,ll pos,ll v=1,ll L=1,ll R=m)
    {
        if(!now) now=new node();
        if(L==R)
        {
            now->v+=v;
            return ;
        }
        ll mid=(L+R)>>1;

        if(pos<=mid) upd(now->ls,pos,v,L,mid);
        else upd(now->rs,pos,v,mid+1,R);

        now->v=(now->ls ? now->ls->v : 0)+(now->rs ? now->rs->v : 0);
    }
    inline ll qry(node *now,ll l,ll r,ll L=1,ll R=m)
    {
        if(!now) return 0;
        if(l==L&&r==R) return now->v;

        ll mid=(L+R)>>1;

        if(r<=mid) return qry(now->ls,l,r,L,mid);
        else if(l>mid) qry(now->rs,l,r,mid+1,R);
        return qry(now->ls,l,mid,L,mid)+qry(now->rs,mid+1,r,mid+1,R);
    }
} seg[maxn<<2];

inline bool cmp(data a,data b)
{
    return (a.a==b.a) ? ((a.b==b.b) ? (a.c<b.c) : (a.b<b.b)) : (a.a<b.a);
}

inline void update(ll now,ll posx,ll posy,ll v,ll L=1,ll R=m)
{
    seg[now].upd(seg[now].rt,posy,v);

    if(L==R) return ;

    ll mid=(L+R)>>1;

    if(posx<=mid) update(now<<1,posx,posy,v,L,mid);
    else update(now<<1|1,posx,posy,v,mid+1,R);
}

inline ll query(ll now,ll xxl,ll xxr,ll yyl,ll yyr,ll L=1,ll R=m)
{
    if(xxl==L&&xxr==R) return seg[now].qry(seg[now].rt,yyl,yyr);
    ll mid=(L+R)>>1;
    if(xxr<=mid) return query(now<<1,xxl,xxr,yyl,yyr,L,mid);
    else if(xxl>mid) return query(now<<1|1,xxl,xxr,yyl,yyr,mid+1,R);
    return query(now<<1,xxl,mid,yyl,yyr,L,mid)+query(now<<1|1,mid+1,xxr,yyl,yyr,mid+1,R);
}

int main(void)
{
    n=read(),m=read();

    for(int i=1;i<=n;i++) s[i].a=read(),s[i].b=read(),s[i].c=read();
    std::sort(s+1,s+n+1,cmp);

    ll sum=1;
    for(int i=1;i<=n;i++)
    {
        if(s[i+1].a==s[i].a&&s[i+1].b==s[i].b&&s[i+1].c==s[i].c)
        {
            sum++;
            continue;
        }

        update(1,s[i].b,s[i].c,sum);

        ans[query(1,1,s[i].b,1,s[i].c)]+=sum;

        sum=1;
    }

    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);

    return 0;
} 

本题可以使用树状数组来代替外层的非动态开点线段树,减小常数

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#define ll int

inline ll read()
{
    ll x=0,f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x*f;
}

const ll maxn=2e5+10;
ll n,m;
ll ans[maxn];
struct data 
{
    ll a,b,c,id;
} s[maxn];

struct seg
{
    struct node
    {
        node *ls,*rs;
        ll v;

        node(ll v=0) : v(v)
        {
            ls=rs=NULL;
        }
    };

    node *rt;

    inline void upd(node *&now,ll pos,ll v=1,ll L=1,ll R=maxn)
    {
        if(!now) now=new node();
        if(L==R)
        {
            now->v+=v;
            return ;
        }
        ll mid=(L+R)>>1;

        if(pos<=mid) upd(now->ls,pos,v,L,mid);
        else upd(now->rs,pos,v,mid+1,R);

        now->v=(now->ls ? now->ls->v : 0)+(now->rs ? now->rs->v : 0);
    }
    inline ll qry(node *now,ll l,ll r,ll L=1,ll R=maxn)
    {
        if(!now) return 0;
        if(l==L&&r==R) return now->v;

        ll mid=(L+R)>>1;

        if(r<=mid) return qry(now->ls,l,r,L,mid);
        else if(l>mid) qry(now->rs,l,r,mid+1,R);
        return qry(now->ls,l,mid,L,mid)+qry(now->rs,mid+1,r,mid+1,R);
    }
} seg[maxn<<2];

inline bool cmp(data a,data b)
{
    return (a.a==b.a) ? ((a.b==b.b) ? (a.c<b.c) : (a.b<b.b)) : (a.a<b.a);
}

inline ll lowbit(ll x)
{
    return x & (-x);
}

inline void update(ll posx,ll posy,ll v)
{
    for(int i=posx;i<=m;i+=lowbit(i)) seg[i].upd(seg[i].rt,posy,v,1,m);
} 

inline ll query(ll posx,ll posy)
{
    ll ret=0;
    for(int i=posx;i;i-=lowbit(i)) ret+=seg[i].qry(seg[i].rt,1,posy,1,m);
    return ret;
}

int main(void)
{
    n=read(),m=read();

    for(int i=1;i<=n;i++) s[i].a=read(),s[i].b=read(),s[i].c=read();
    std::sort(s+1,s+n+1,cmp);

    ll sum=1;
    for(int i=1;i<=n;i++)
    {
        if(s[i+1].a==s[i].a&&s[i+1].b==s[i].b&&s[i+1].c==s[i].c)
        {
            sum++;
            continue;
        }

        update(s[i].b,s[i].c,sum);

        ans[query(s[i].b,s[i].c)]+=sum;

        sum=1;
    }

    for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);

    return 0;
} 

线段树套平衡树

思路

最外层使用线段树进行维护,而每一个线段树的节点都是一棵平衡树,平衡树中存储了所有下标在该线段树区间内的节点的信息

例题

P3380 【模板】二逼平衡树(树套树)

本题可以使用线段树套平衡树进行维护

对于操作,我们可以进行维护

  1. 查询 k 在区间内的排名:在线段树上找到区间的对应节点,然后在每个节点内的平衡树内查询对应数的排名并求和即可
  2. 查询排名为 k 的数的大小:注意到排名随数值的大小一定是递增的,所以我们可以二分一个答案,然后对这个值进行 1 中的查询操作,即可对应找到这个值
  3. 单点修改:在线段树中找到这个点所对应的线段树节点,然后在这些节点中将该点删除,并插入新的值
  4. 查询前驱/后继:只需要对每一个区间都进行查询,然后取最大值/最小值进行记录即可

对于每一种操作,在其最外层线段树上都需要 \log n 的时间去维护,而对于每一个节点,在其所维护的平衡树上都需要花费 \log n 的时间去维护,所以对于除了操作 2 的每一次操作的总复杂度为 \log ^2 n ,而操作 2 还有一次二分,所以其复杂度为 \log^3 n

而每个元素会被加入 \log n 个平衡树,所以其空间复杂度为 n \log n

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<cstring>
#define ll long long
#define ld long double

inline ll read()
{
    ll x=0,f=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x*f;
}

const ll inf=2147483647;
const ll maxn=5e4+10;
ll n,m,tot;
ll a[maxn];

struct balanced_tree
{
    ll fa,son[2];
    ll v;
    ll siz,cnt;
} tre[(maxn*20)<<4];

struct Segment_tree
{
    ll rot,ls,rs;
} seg[maxn<<2];

inline ll _max(ll a,ll b)
{
    return a<b ? b : a;
}

inline ll _min(ll a,ll b)
{
    return a<b ? a : b;
}

inline void pushup(ll x)
{
    tre[x].siz=tre[tre[x].son[0]].siz+tre[tre[x].son[1]].siz+tre[x].cnt;
}

inline ll newnode(ll v,ll fa)
{
    tre[++tot].siz=1;
    tre[tot].cnt=1;
    tre[tot].v=v;
    tre[tot].fa=fa;
    tre[tot].son[0]=tre[tot].son[1]=0;
    pushup(tot); 
    return tot;
}

inline ll get(ll x)
{
    return tre[tre[x].fa].son[1]==x;
}

inline void connect(ll x,ll y,ll z)
{
    if(x) tre[x].fa=y;
    if(y) tre[y].son[z]=x;
}

inline void rotate(ll x)
{
    ll y=tre[x].fa,z=tre[y].fa;
    ll flag1=get(x),flag2=get(y);

    connect(tre[x].son[flag1^1],y,flag1);
    connect(y,x,flag1^1);
    connect(x,z,flag2);

    pushup(y),pushup(x);
}

inline void Splay(ll x,ll gol,ll node)
{
    for(int fa;(fa=tre[x].fa)!=gol;rotate(x))
    {
        if(tre[fa].fa!=gol)
        {
            rotate(get(x)==get(fa) ? fa : x);
        }
    }
    if(gol==0) seg[node].rot=x;
}

inline void insert(ll x,ll node)
{
    ll nw=seg[node].rot,fa=0;
    if(!nw)
    {
        seg[node].rot=newnode(x,0);
        return ;
    }
    while(nw&&tre[nw].v!=x)
    {
        fa=nw;
        nw=tre[nw].son[tre[nw].v<x];
    }
    if(nw&&tre[nw].v) tre[nw].cnt++;
    else
    {
        nw=newnode(x,fa);
        if(fa) tre[fa].son[tre[fa].v<x]=nw;
    }
    pushup(fa);
    pushup(nw);
    Splay(nw,0,node);
}

inline void find(ll x,ll node)
{
    ll nw=seg[node].rot;
    if(!nw) return ;
    while(tre[nw].son[tre[nw].v<x]&&(tre[nw].v!=x))
    {
        nw=tre[nw].son[tre[nw].v<x];
    }
//  printf("%lld\n",nw);
    Splay(nw,0,node);
}

inline ll pre(ll x,ll node)
{
    find(x,node);
    ll nw=seg[node].rot;
    if(tre[nw].v<x) return nw;
    nw=tre[nw].son[0];
    while(tre[nw].son[1]) nw=tre[nw].son[1];
    return nw;
}

inline ll nxt(ll x,ll node)
{
    find(x,node);
    ll nw=seg[node].rot;
    if(tre[nw].v>x) return nw;
    nw=tre[nw].son[1];
    while(tre[nw].son[0]) nw=tre[nw].son[0];
    return nw;
}

inline void del(ll x,ll node)
{
    ll nw=seg[node].rot;
    ll pr=pre(x,node),nx=nxt(x,node);
    Splay(nx,0,node),Splay(pr,nx,node);
    nw=tre[pr].son[1];
    if(tre[nw].cnt>1)
    {
        tre[nw].cnt--;
        pushup(nw);
        Splay(nw,0,node);
    }
    else tre[pr].son[1]=0;
    pushup(pr);
}

inline void build(ll node,ll l,ll r)
{
    insert(inf,node),insert(-inf,node);
    if(l==r) return ;
    else
    {
        ll mid=(l+r)>>1;
        build(node<<1,l,mid);
        build(node<<1|1,mid+1,r);
    }
}

inline void seg_insert(ll node,ll x,ll l,ll r,ll v)
{
    insert(v,node);

    if(l==r) return ;

    ll mid=(l+r)>>1;
    if(x<=mid) seg_insert(node<<1,x,l,mid,v);
    else seg_insert(node<<1|1,x,mid+1,r,v);
}

inline void seg_update(ll node,ll x,ll l,ll r,ll v)
{
    del(a[x],node),insert(v,node);
    if(l==r&&l==x)
    {
        a[x]=v;
        return ;
    }
    ll mid=(l+r)>>1;
    if(x<=mid) seg_update(node<<1,x,l,mid,v);
    else seg_update(node<<1|1,x,mid+1,r,v);
}

inline bool outofrange(ll l,ll r,ll L,ll R)
{
    return (L>r)||(l>R);
}

inline bool inrange(ll l,ll r,ll L,ll R)
{
    return (L<=l)&&(r<=R);
}

inline ll seg_rk(ll node,ll x,ll l,ll r,ll ql,ll qr)
{
    if(inrange(l,r,ql,qr))
    {
        find(x,node);
        ll nw=seg[node].rot;

        if(tre[nw].v>=x) return tre[tre[nw].son[0]].siz-1;
        else return tre[tre[nw].son[0]].siz+tre[nw].cnt-1;
    }
    else if(outofrange(l,r,ql,qr)) return 0;

    ll mid=(l+r)>>1;
    return seg_rk(node<<1,x,l,mid,ql,qr)+seg_rk(node<<1|1,x,mid+1,r,ql,qr);
}

inline ll seg_kth(ll ql,ll qr,ll x)
{
    ll l=0,r=inf,ans=0,ret=0;
    while(l<=r)
    {
        ll mid=(l+r)>>1;
        ret=seg_rk(1,mid,1,n,ql,qr)+1;
        if(ret>x) r=mid-1;
        else l=mid+1,ans=mid;
    }
    return ans;
}

inline ll seg_pre(ll node,ll x,ll l,ll r,ll ql,ll qr)
{
    if(outofrange(l,r,ql,qr)) return -inf;
    else if(inrange(l,r,ql,qr)) return tre[pre(x,node)].v;

    ll mid=(l+r)>>1;
    return _max(seg_pre(node<<1,x,l,mid,ql,qr),seg_pre(node<<1|1,x,mid+1,r,ql,qr));
}

inline ll seg_nxt(ll node,ll x,ll l,ll r,ll ql,ll qr)
{
    if(outofrange(l,r,ql,qr)) return inf;
    else if(inrange(l,r,ql,qr)) return tre[nxt(x,node)].v;

    ll mid=(l+r)>>1;
    return _min(seg_nxt(node<<1,x,l,mid,ql,qr),seg_nxt(node<<1|1,x,mid+1,r,ql,qr));
}

int main(void)
{
    n=read(),m=read();
    build(1,1,n);

    for(int i=1;i<=n;i++)
    {
        a[i]=read();
        seg_insert(1,i,1,n,a[i]);
    }
    for(int i=1;i<=m;i++)
    {
        ll op=read();
        if(op==1)
        {
            ll l=read(),r=read(),k=read();
            printf("%lld\n",seg_rk(1,k,1,n,l,r)+1); 
        }
        if(op==2)
        {
            ll l=read(),r=read(),k=read();
            printf("%lld\n",seg_kth(l,r,k));
        }
        if(op==3)
        {
            ll pos=read(),k=read();
            seg_update(1,pos,1,n,k);
        }
        if(op==4)
        {
            ll l=read(),r=read(),k=read();
            printf("%lld\n",seg_pre(1,k,1,n,l,r));
        }
        if(op==5)
        {
            ll l=read(),r=read(),k=read();
            printf("%lld\n",seg_nxt(1,k,1,n,l,r));
        }
    }

    return 0;
}
/*
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
*/