浅谈维护叶子节点及其所构成的虚树的一类线段树

· · 休闲·娱乐

早年写的,现在来投娱乐区了/fendou。

LCA

最基础的操作,目的是求解两个区间在线段树上的 LCA。

首先不难发现,两个区间的 LCA 等价于两个区间中任意取两点的 LCA,因为:

LCA(u,v) = LCA(son_u,son_v)

所以问题转化成求两个点的 LCA。

我们考虑记录下每个节点的编号,然后通过异或完求最低位的 1 的方式求出 LCA 的深度,然后记录下 LCA 右端点的编号。

在回收节点的时候可以释放储存右端点编号的空间,但是这里为了方便就不这样做了。

int w(int x){
    int l=1,r=(1<<18),v=1,res=0;
    while(l!=r){
        int mid=(l+r)>>1;
        if(x<=mid){
            r=mid;
        }
        else{
            l=mid+1;
            res+=v;
        }
        v<<=1;
    }
    return res;
}
inline pair<int,int> LCA(int u,int v){
    if(val[u]==0) val[u]=w(u);
    if(val[v]==0) val[v]=w(v);
    int f=val[u]^val[v];
    f=f&(-f);
    int len=1<<(18-lg[f]);
    int pos=(u-1)/len+1;
    val[(pos-1)*len+1]=(f-1)&val[u];
    return make_pair((pos-1)*len+1,pos*len);
}

这里只会对每个叶子节点执行一次求解编号的操作,故时间复杂度是 O(n \log V) 的。

注意到这里记录右端点的编号可能会带来一定的空间常数,所以还有一种时间 O(\log \log n) 但是空间常数小的写法:

bool check(int u,int v,int len){
    return ((u-1)/(1<<len)+1==(v-1)/(1<<len)+1);
}
inline pair<int,int> LCA(int u,int v){
    int l=-1,r=30;
    while(l+1<r){
        int mid=(l+r)/2;
        if(check(u,v,mid)==true) r=mid;
        else l=mid;
    }
    int posu=(u-1)/(1<<r)+1;
    return make_pair((posu-1)*(1<<r)+1,posu*(1<<r));
}

可以证明在无需分裂合并的情况下一次操作只会求一次 LCA 因此在当平衡树使用时上面的写法更为合适

ADD

目的是单点修改,考虑在线段树上遍历,根据节点情况的不同大力分类讨论,建出新点与原来某个节点的 LCA。

inline void add(int x,int pos,int v){
    int mid=(tree[x].l+tree[x].r)>>1;
    if(pos<=mid){
        if(tree[x].ls==0){
            int y=clone();
            tree[x].ls=y;
            tree[y].sum+=v;
            tree[y].l=tree[y].r=pos;
            pushup(x);
            return ;
        }
        else{
            if(tree[tree[x].ls].l==tree[tree[x].ls].r){
                if(tree[tree[x].ls].l==pos){
                    tree[tree[x].ls].sum+=v;
                    pushup(x);
                    return ;
                }
                pair<int,int> lca=LCA(pos,tree[tree[x].ls].l);
                int y=clone();
                tree[y].l=lca.first;
                tree[y].r=lca.second;
                int A=y;
                int B=tree[x].ls;
                int z=clone();
                tree[z].sum+=v;
                tree[z].l=tree[z].r=pos;
                int C=z;
                tree[x].ls=A;
                if(tree[B].l>tree[C].l) swap(B,C);
                tree[A].ls=B;
                tree[A].rs=C;
                pushup(A);
                pushup(x);
                return ;
            }
            else{
                if(pos>tree[tree[x].ls].r||pos<tree[tree[x].ls].l){
                    pair<int,int> lca=LCA(pos,tree[tree[x].ls].l);
                    int y=clone();
                    tree[y].l=lca.first;
                    tree[y].r=lca.second;
                    int A=y;
                    int B=tree[x].ls;
                    int z=clone();
                    tree[z].sum+=v;
                    tree[z].l=tree[z].r=pos;
                    int C=z;
                    tree[x].ls=A;
                    if(tree[B].l>tree[C].l) swap(B,C);
                    tree[A].ls=B;
                    tree[A].rs=C;
                    pushup(A);
                    pushup(x);
                }
                else{
                    add(tree[x].ls,pos,v);
                    pushup(x);
                    return ;
                }
            }
        }
    }
    else{
        if(tree[x].rs==0){
            int y=clone();
            tree[x].rs=y;
            tree[y].sum+=v;
            tree[y].l=tree[y].r=pos;
            pushup(x);
            return ;
        }
        else{
            if(tree[tree[x].rs].l==tree[tree[x].rs].r){
                if(tree[tree[x].rs].r==pos){
                    tree[tree[x].rs].sum+=v;
                    pushup(x);
                    return ;
                }
                pair<int,int> lca=LCA(pos,tree[tree[x].rs].l);
                int y=clone();
                tree[y].l=lca.first;
                tree[y].r=lca.second;
                int A=y;
                int B=tree[x].rs;
                int z=clone();
                tree[z].sum+=v;
                tree[z].l=tree[z].r=pos;
                int C=z;
                tree[x].rs=A;
                if(tree[B].l>tree[C].l) swap(B,C);
                tree[A].ls=B;
                tree[A].rs=C;
                pushup(A);
                pushup(x);
                return ;
            }
            else{
                if(pos<tree[tree[x].rs].l||pos>tree[tree[x].rs].r){
                    pair<int,int> lca=LCA(pos,tree[tree[x].rs].l);
                    int y=clone();
                    tree[y].l=lca.first;
                    tree[y].r=lca.second;
                    int A=y;
                    int B=tree[x].rs;
                    int z=clone();
                    tree[z].sum+=v;
                    tree[z].l=tree[z].r=pos;
                    int C=z;
                    tree[x].rs=A;
                    if(tree[C].l>tree[B].l) swap(C,B);
                    tree[A].ls=C;
                    tree[A].rs=B;
                    pushup(A);
                    pushup(x);
                }
                else{
                    add(tree[x].rs,pos,v);
                    pushup(x);
                    return ;
                }
            }
        }
    }
}

不难发现这个函数中求解一次 LCA 后必定返回,故应证了上面的话。

QUERY

查询区间和,这一部分与普通线段树无异。

inline int query(int x,int l,int r){
    int lt=tree[x].l;
    int rt=tree[x].r;
    if(x==0) return 0;
    if(rt<l||r<lt){
        return 0;
    }
    if(l<=lt&&rt<=r){
        return tree[x].sum;
    }
    int res=0,mid=(lt+rt)>>1;
    res+=query(tree[x].ls,l,r);
    res+=query(tree[x].rs,l,r);     
    return res;
}

KTH

查询第 k 大,用来作为线段树上二分的一个模板,与普通线段树也无异。

inline int kth(int x,int k){
    if(tree[x].l==tree[x].r) return tree[x].l;
    if(k<=tree[tree[x].ls].sum) return kth(tree[x].ls,k);
    else return kth(tree[x].rs,k-tree[tree[x].ls].sum);
} 

MERGE

合并两棵线段树,这一部分需要根据合并的两棵线段树所管辖的区间来决定是否需要新建节点或者更改所管辖的区间来保证虚树上只存在叶子结点的 LCA 一性质。

int merge(int a,int b){
    if(a==0||b==0) return a+b;
    if((tree[a].r-tree[a].l+1)<(tree[b].r-tree[b].l+1)) swap(a,b);
    if(tree[a].l==tree[b].l&&tree[a].r==tree[b].r){
        if(tree[a].l==tree[a].r){
            tree[a].sum+=tree[b].sum;
            brush.push(b);
            return a;
        }
        int L=tree[b].ls,R=tree[b].rs;
        brush.push(b);
        tree[a].ls=merge(tree[a].ls,L);
        tree[a].rs=merge(tree[a].rs,R);
        pushup(a);
        return a;
    }   
    if(tree[a].l<=tree[b].l&&tree[b].r<=tree[a].r){
        int mid=(tree[a].l+tree[a].r)>>1;
        if(tree[b].l<=mid) tree[a].ls=merge(tree[a].ls,b);
        else tree[a].rs=merge(tree[a].rs,b);
        pair<int,int> lca=LCA(tree[tree[a].ls].l,tree[tree[a].rs].l);
        tree[a].l=lca.first,tree[a].r=lca.second;
        pushup(a);
        return a;
    }
    if(tree[a].l>tree[b].l) swap(a,b);
    if(tree[a].r<tree[b].l){
        pair<int,int> lca=LCA(tree[a].l,tree[b].l);
        int y=clone();
        tree[y].l=lca.first;
        tree[y].r=lca.second;
        tree[y].ls=a;
        tree[y].rs=b;
        pushup(y);      
        return y;
    }
    else{
        int L=tree[b].ls,R=tree[b].rs;
        brush.push(b);
        tree[a].ls=merge(tree[a].ls,L);
        tree[a].rs=merge(tree[a].rs,R);
        pushup(a);
        return a;
    }
}

在这个过程中 LCA 的求解至多会进行 \log V 次,因为每一层的合并都有可能需要调整虚树节点。

SPLIT

这一部分需要递归不断处理,将一棵子树的左右儿子拆开再合并。

inline void split(int &x,int &y,int l,int r){
    if(!x) return ;
    int lt=tree[x].l,rt=tree[x].r;
    if(rt<l||r<lt) return ;
    if(l<=lt&&rt<=r){
        y=x;
        x=0;
        return ;
    }
    if(!y) y=clone();
    split(tree[x].ls,tree[y].ls,l,r);
    split(tree[x].rs,tree[y].rs,l,r);
    if(tree[y].ls==0&&tree[y].rs==0){
        brush.push(y);
        y=0;
    }
    else if(tree[y].ls==0){
        brush.push(y);
        y=tree[y].rs;
    }
    else if(tree[y].rs==0){
        brush.push(y);
        y=tree[y].ls;
    }
    else{
        pair<int,int> lca=LCA(tree[tree[y].ls].l,tree[tree[y].rs].l);
        tree[y].l=lca.first,tree[y].r=lca.second;
        pushup(y);
    }
    if(tree[x].ls==0&&tree[x].rs==0){
        brush.push(x);
        x=0;
    }
    else if(tree[x].ls==0){
        brush.push(x);
        x=tree[x].rs;
    }
    else if(tree[x].rs==0){
        brush.push(x);
        x=tree[x].ls;
    }
    else{
        pair<int,int> lca=LCA(tree[tree[x].ls].l,tree[tree[x].rs].l);
        tree[x].l=lca.first,tree[x].r=lca.second;
        pushup(x);
    }
    return ;
}

同样需要调用 \log V 次 LCA 求解保证虚树的形态。

MAINTAIN

再前面的操作中会使一棵子树的根节点不断变化,但是我们插入查询时为了方便需要保证根节点管辖的区间是 [1,top],所以需要此函数强制定义一个根节点。

inline void maintain(int &x){    
    if(tree[x].l!=1||tree[x].r!=(1<<20)){
        int y=clone();
        tree[y].l=1,tree[y].r=(1<<20);
        int mid=(tree[y].l+tree[y].r)>>1;
        if(!x){
            x=y;
            return ;
        }
        if(tree[x].r<=mid){
            tree[y].ls=x;
        }
        else{
            tree[y].rs=x;
        }
        pushup(y);
        x=y;
    }
    return ;
}

代码整合

以下是【模板】线段树分裂的代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 4e5+114;
int tot;
stack<int> brush;
struct Node{
    int ls,rs;
    int sum;
    int l,r;
}tree[maxn<<1];
int clone(){
    int New;
    if(brush.size()>0){
        New=brush.top();
        brush.pop();
    }
    else{
        New=++tot;
    }
    tree[New].l=tree[New].r=tree[New].sum=tree[New].ls=tree[New].rs=0;
    return New;
}
int root[maxn],val[maxn],lg[maxn];
inline int w(int x){
    int l=1,r=(1<<18),v=1,res=0;
    while(l!=r){
        int mid=(l+r)>>1;
        if(x<=mid){
            r=mid;
        }
        else{
            l=mid+1;
            res+=v;
        }
        v<<=1;
    }
    return res;
}
inline pair<int,int> LCA(int u,int v){
    if(val[u]==0) val[u]=w(u);
    if(val[v]==0) val[v]=w(v);
    int f=val[u]^val[v];
    f=f&(-f);
    int len=1<<(18-lg[f]);
    int pos=(u-1)/len+1;
    val[(pos-1)*len+1]=(f-1)&val[u];
    return make_pair((pos-1)*len+1,pos*len);
}
inline void pushup(int x){
    tree[x].sum=tree[tree[x].ls].sum+tree[tree[x].rs].sum;
}
inline void add(int x,int pos,int v){
    int mid=(tree[x].l+tree[x].r)>>1;
    if(pos<=mid){
        if(tree[x].ls==0){
            int y=clone();
            tree[x].ls=y;
            tree[y].sum+=v;
            tree[y].l=tree[y].r=pos;
            pushup(x);
            return ;
        }
        else{
            if(tree[tree[x].ls].l==tree[tree[x].ls].r){
                if(tree[tree[x].ls].l==pos){
                    tree[tree[x].ls].sum+=v;
                    pushup(x);
                    return ;
                }
                pair<int,int> lca=LCA(pos,tree[tree[x].ls].l);
                int y=clone();
                tree[y].l=lca.first;
                tree[y].r=lca.second;
                int A=y;
                int B=tree[x].ls;
                int z=clone();
                tree[z].sum+=v;
                tree[z].l=tree[z].r=pos;
                int C=z;
                tree[x].ls=A;
                if(tree[B].l>tree[C].l) swap(B,C);
                tree[A].ls=B;
                tree[A].rs=C;
                pushup(A);
                pushup(x);
                return ;
            }
            else{
                if(pos>tree[tree[x].ls].r||pos<tree[tree[x].ls].l){
                    pair<int,int> lca=LCA(pos,tree[tree[x].ls].l);
                    int y=clone();
                    tree[y].l=lca.first;
                    tree[y].r=lca.second;
                    int A=y;
                    int B=tree[x].ls;
                    int z=clone();
                    tree[z].sum+=v;
                    tree[z].l=tree[z].r=pos;
                    int C=z;
                    tree[x].ls=A;
                    if(tree[B].l>tree[C].l) swap(B,C);
                    tree[A].ls=B;
                    tree[A].rs=C;
                    pushup(A);
                    pushup(x);
                }
                else{
                    add(tree[x].ls,pos,v);
                    pushup(x);
                    return ;
                }
            }
        }
    }
    else{
        if(tree[x].rs==0){
            int y=clone();
            tree[x].rs=y;
            tree[y].sum+=v;
            tree[y].l=tree[y].r=pos;
            pushup(x);
            return ;
        }
        else{
            if(tree[tree[x].rs].l==tree[tree[x].rs].r){
                if(tree[tree[x].rs].r==pos){
                    tree[tree[x].rs].sum+=v;
                    pushup(x);
                    return ;
                }
                pair<int,int> lca=LCA(pos,tree[tree[x].rs].l);
                int y=clone();
                tree[y].l=lca.first;
                tree[y].r=lca.second;
                int A=y;
                int B=tree[x].rs;
                int z=clone();
                tree[z].sum+=v;
                tree[z].l=tree[z].r=pos;
                int C=z;
                tree[x].rs=A;
                if(tree[B].l>tree[C].l) swap(B,C);
                tree[A].ls=B;
                tree[A].rs=C;
                pushup(A);
                pushup(x);
                return ;
            }
            else{
                if(pos<tree[tree[x].rs].l||pos>tree[tree[x].rs].r){
                    pair<int,int> lca=LCA(pos,tree[tree[x].rs].l);
                    int y=clone();
                    tree[y].l=lca.first;
                    tree[y].r=lca.second;
                    int A=y;
                    int B=tree[x].rs;
                    int z=clone();
                    tree[z].sum+=v;
                    tree[z].l=tree[z].r=pos;
                    int C=z;
                    tree[x].rs=A;
                    if(tree[C].l>tree[B].l) swap(C,B);
                    tree[A].ls=C;
                    tree[A].rs=B;
                    pushup(A);
                    pushup(x);
                }
                else{
                    add(tree[x].rs,pos,v);
                    pushup(x);
                    return ;
                }
            }
        }
    }
}
inline int query(int x,int l,int r){
    int lt=tree[x].l;
    int rt=tree[x].r;
    if(x==0) return 0;
    if(rt<l||r<lt){
        return 0;
    }
    if(l<=lt&&rt<=r){
        return tree[x].sum;
    }
    int res=0,mid=(lt+rt)>>1;
    res+=query(tree[x].ls,l,r);
    res+=query(tree[x].rs,l,r);     
    return res;
}
int merge(int a,int b){
    if(a==0||b==0) return a+b;
    if((tree[a].r-tree[a].l+1)<(tree[b].r-tree[b].l+1)) swap(a,b);
    if(tree[a].l==tree[b].l&&tree[a].r==tree[b].r){
        if(tree[a].l==tree[a].r){
            tree[a].sum+=tree[b].sum;
            brush.push(b);
            return a;
        }
        int L=tree[b].ls,R=tree[b].rs;
        brush.push(b);
        tree[a].ls=merge(tree[a].ls,L);
        tree[a].rs=merge(tree[a].rs,R);
        pushup(a);
        return a;
    }   
    if(tree[a].l<=tree[b].l&&tree[b].r<=tree[a].r){
        int mid=(tree[a].l+tree[a].r)>>1;
        if(tree[b].l<=mid) tree[a].ls=merge(tree[a].ls,b);
        else tree[a].rs=merge(tree[a].rs,b);
        pair<int,int> lca=LCA(tree[tree[a].ls].l,tree[tree[a].rs].l);
        tree[a].l=lca.first,tree[a].r=lca.second;
        pushup(a);
        return a;
    }
    if(tree[a].l>tree[b].l) swap(a,b);
    if(tree[a].r<tree[b].l){
        pair<int,int> lca=LCA(tree[a].l,tree[b].l);
        int y=clone();
        tree[y].l=lca.first;
        tree[y].r=lca.second;
        tree[y].ls=a;
        tree[y].rs=b;
        pushup(y);      
        return y;
    }
    else{
        int L=tree[b].ls,R=tree[b].rs;
        brush.push(b);
        tree[a].ls=merge(tree[a].ls,L);
        tree[a].rs=merge(tree[a].rs,R);
        pushup(a);
        return a;
    }
}
inline void maintain(int &x){    
    if(tree[x].l!=1||tree[x].r!=(1<<18)){
        int y=clone();
        tree[y].l=1,tree[y].r=(1<<18);
        int mid=(tree[y].l+tree[y].r)>>1;
        if(!x){
            x=y;
            return ;
        }
        if(tree[x].r<=mid){
            tree[y].ls=x;
        }
        else{
            tree[y].rs=x;
        }
        pushup(y);
        x=y;
    }
    return ;
}
inline void split(int &x,int &y,int l,int r){
    if(!x) return ;
    int lt=tree[x].l,rt=tree[x].r;
    if(rt<l||r<lt) return ;
    if(l<=lt&&rt<=r){
        y=x;
        x=0;
        return ;
    }
    if(!y) y=clone();
    split(tree[x].ls,tree[y].ls,l,r);
    split(tree[x].rs,tree[y].rs,l,r);
    if(tree[y].ls==0&&tree[y].rs==0){
        brush.push(y);
        y=0;
    }
    else if(tree[y].ls==0){
        brush.push(y);
        y=tree[y].rs;
    }
    else if(tree[y].rs==0){
        brush.push(y);
        y=tree[y].ls;
    }
    else{
        pair<int,int> lca=LCA(tree[tree[y].ls].l,tree[tree[y].rs].l);
        tree[y].l=lca.first,tree[y].r=lca.second;
        pushup(y);
    }
    if(tree[x].ls==0&&tree[x].rs==0){
        brush.push(x);
        x=0;
    }
    else if(tree[x].ls==0){
        brush.push(x);
        x=tree[x].rs;
    }
    else if(tree[x].rs==0){
        brush.push(x);
        x=tree[x].ls;
    }
    else{
        pair<int,int> lca=LCA(tree[tree[x].ls].l,tree[tree[x].rs].l);
        tree[x].l=lca.first,tree[x].r=lca.second;
        pushup(x);
    }
    return ;
}
inline int kth(int x,int k){
    if(tree[x].l==tree[x].r) return tree[x].l;
    if(k<=tree[tree[x].ls].sum) return kth(tree[x].ls,k);
    else return kth(tree[x].rs,k-tree[tree[x].ls].sum);
} 
void init(int pos){
    root[pos]=clone();
    tree[root[pos]].l=1;
    tree[root[pos]].r=(1<<18);
}
int n,m;
int cnt;
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
for(int i=0;i<=18;i++) lg[1<<i]=i;
cin>>n>>m;
cnt++;
init(cnt);
for(int i=1;i<=n;i++){
    int x;
    cin>>x;
    add(root[cnt],i,x);
}
while(m--){
    int opt;
    cin>>opt;
    if(opt==0){
        int x,y,z;
        cin>>x>>y>>z;
        cnt++;
        split(root[x],root[cnt],y,z);
        maintain(root[x]);
        maintain(root[cnt]);
    }
    else if(opt==1){
        int x,y;
        cin>>x>>y;
        root[x]=merge(root[x],root[y]);
        maintain(root[x]);
    }
    else if(opt==2){
        int x,y,z;
        cin>>x>>y>>z;
        add(root[x],z,y);
    }
    else if(opt==3){
        int x,y,z;
        cin>>x>>y>>z;
        cout<<query(root[x],y,z)<<'\n';
    } 
    else{
        int x,y;
        cin>>x>>y;
        if(tree[root[x]].sum<y){
            cout<<"-1\n";
        }
        else{
            cout<<kth(root[x],y)<<'\n';
        }
    }
}
return 0;
}

此代码开启 O2 优化的情况下最大点用时 363ms 27.16MB。