线段树上二分

· · 个人记录

线段树上二分

线段树上二分的定义:就是把二分放在线段树上进行。

举个例子: 给定一个严格不下降的序列 a ,和 q 次询问。每次询问输入一个数 k ,输出所有大于等于 k 的数的数量。

我们可以用二分查找,复杂度 O(qlogn) ,代码如下:

#include<bits/stdc++.h>
using namespace std;
int q,n,a[1000005],k;
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    scanf("%d",&q);
    while(q--)
    {
        scanf("%d",&k);
        int l=1,r=n;
        while(l<r)
        {
            int mid=(l+r)/2;
            if(a[mid]<k) l=mid+1;
            else r=mid;
        }
        printf("%d\n",n-l+1);
    }
    return 0;
}

而如果用线段树上二分的话,我们需要记录线段树上每一段区间的最小值和最大值,然后二分找到所有合法的区间,统计答案,复杂度 O(qlogn) ,代码如下:

#include<bits/stdc++.h>
using namespace std;
int q,n,a[1000005],k,ls[4000005],re[4000005];
void build(int node,int l,int r)
{
    if(l==r)
    {
        ls[node]=re[node]=a[l];
        return;
    }
    int mid=(l+r)/2;
    build(node*2,l,mid);
    build(node*2+1,mid+1,r);
    ls[node]=ls[node*2];
    re[node]=re[node*2+1]; 
}
int get(int node,int l,int r,int x)
{
    if(re[node]<x) return 0;
    if(ls[node]>=x) return r-l+1;
    int mid=(l+r)/2,ret=get(node*2+1,mid+1,r,x);
    if(re[node*2]>=x) ret+=get(node*2,l,mid,x);
    return ret;
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    build(1,1,n);
    scanf("%d",&q);
    while(q--)
    {
        scanf("%d",&k);
        printf("%d\n",get(1,1,n,k));
    }
    return 0;
}

其中 ls[node] 记录区间 node 的左端点的值,即最小值;re[node] 记录区间 node 的右端点的值,即最大值。

那么,线段树上二分相较于直接二分有哪些优势呢?

如果你要二分的部分的信息是用一个线段树维护的,那么在线段树上二分时我们可以同时得到相应部分用线段树维护的一些信息;而直接二分则还需要在二分中查询线段树上的一些信息,复杂度更高。

例题

P5579 (https://www.luogu.com.cn/problem/P5579)

代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
int n,m;
ll a[500005],t[2000005],lb1[2000005],lb2[2000005],ls[2000005],re[2000005];
ll get(int node,int l,int r,ll b)
{
    if(re[node]<=b) return 0;
    if(ls[node]>b)
    {
        ll tmp=t[node];
        t[node]=(r-l+1)*b;
        ls[node]=re[node]=lb2[node]=b;
        lb1[node]=0;
        return tmp-t[node];
    }
    int mid=(l+r)/2;
    if(lb2[node]!=-1&&l!=r)
    {
        t[node*2]=(mid-l+1)*lb2[node];
        ls[node*2]=re[node*2]=lb2[node*2]=lb2[node];
        t[node*2+1]=(r-mid)*lb2[node];
        ls[node*2+1]=re[node*2+1]=lb2[node*2+1]=lb2[node];
        lb1[node*2]=lb1[node*2+1]=0;
        lb2[node]=-1;
    }
    if(lb1[node]&&l!=r)
    {
        t[node*2]+=lb1[node]*(a[mid]-a[l-1]);
        ls[node*2]+=lb1[node]*(a[l]-a[l-1]);
        re[node*2]+=lb1[node]*(a[mid]-a[mid-1]);
        lb1[node*2]+=lb1[node];
        t[node*2+1]+=lb1[node]*(a[r]-a[mid]);
        ls[node*2+1]+=lb1[node]*(a[mid+1]-a[mid]);
        re[node*2+1]+=lb1[node]*(a[r]-a[r-1]);
        lb1[node*2+1]+=lb1[node];
        lb1[node]=0;
    }
    ll ret=get(node*2+1,mid+1,r,b);
    if(re[node*2]>b) ret+=get(node*2,l,mid,b);
    t[node]=t[node*2]+t[node*2+1];
    return ret;
}
int main()
{
    memset(lb2,-1,sizeof(lb2));
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    sort(a+1,a+1+n);
    for(int i=2;i<=n;i++) a[i]+=a[i-1];
    ll last=0,d,b;
    while(m--)
    {
        scanf("%lld%lld",&d,&b);
        t[1]+=(d-last)*a[n];
        ls[1]+=(d-last)*a[1];
        re[1]+=(d-last)*(a[n]-a[n-1]);
        lb1[1]+=d-last;
        last=d;
        printf("%lld\n",get(1,1,n,b));
    }
    return 0;
}

例题

P3224 (https://www.luogu.com.cn/problem/P3224)

代码

#include<bits/stdc++.h>
using namespace std;
int n,m,q,a[100005],fa[100005],pos[100005]={-1},rt[100005],ls[2000005],rs[2000005],sum[2000005],cnt;
void update(int &o,int l,int r,int x)
{
    if(!o) o=++cnt;
    if(l==r)
    {
        sum[o]=1;
        return;
    }
    int mid=(l+r)/2;
    if(x<=mid) update(ls[o],l,mid,x);
    else update(rs[o],mid+1,r,x);
    sum[o]=sum[ls[o]]+sum[rs[o]];
}
void merge(int &x,int y,int l,int r)
{
    if(!x)
    {
        x=y;
        return;
    }
    if(!y) return;
    if(l==r)
    {
        sum[x]+=sum[y];
        sum[y]=0;
        return;
    }
    int mid=(l+r)/2;
    merge(ls[x],ls[y],l,mid);
    merge(rs[x],rs[y],mid+1,r);
    sum[x]=sum[ls[x]]+sum[rs[x]];
}
int findf(int x)
{
    if(fa[x]==x) return x;
    return fa[x]=findf(fa[x]);
}
int ask(int o,int l,int r,int k)
{
    if(k>sum[o]||!o) return 0;
    if(l==r) return l;
    int x=sum[ls[o]],mid=(l+r)/2;
    if(k<=x) return ask(ls[o],l,mid,k);
    return ask(rs[o],mid+1,r,k-x);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) fa[i]=i;
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        pos[a[i]]=i;
        update(rt[i],1,n,a[i]);
    }
    int x,y;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        x=findf(x);
        y=findf(y);
        if(x==y) continue;
        fa[y]=x;
        merge(rt[x],rt[y],1,n);
    }
    scanf("%d",&q);
    char op;
    while(q--)
    {
        cin>>op;
        scanf("%d%d",&x,&y);
        if(op=='Q')
        {
            x=findf(x);
            printf("%d\n",pos[ask(rt[x],1,n,y)]);
        }
        else
        {
            x=findf(x);
            y=findf(y);
            if(x==y) continue;
            fa[y]=x;
            merge(rt[x],rt[y],1,n);
        }
    }
    return 0;
}