权值线段树

· · 算法·理论

基本的线段树节点对应的区间为原数列的下标是这段区间的数的 max/min/sum/...而权值线段树的节点对应区间为值域,节点的值为该区间内的数在整个数列中出现的次数之和

对于一个数列,普通线段树可以维护某个子数组中数的和,而权值线段树可以维护某个区间内的数出现次数,感性理解一下:

维护的值域有时可能过大,所以需要先离散化

建树

几乎和普通线段树一样,上代码:

void build(int k,int l,int r)
{
    if(l == r)
    {
        s[k] = c[l];//c[l]是l出现的次数
        return;
    }
    int mid = (l + r) >> 1;
    build(k << 1,l,mid);
    build(k << 1 | 1,mid + 1,r);
    s[k] = s[k << 1] + s[k << 1 | 1];
}

添加单个数

上代码:

void mdf(int k,int l,int r,int x)
{
    if(l == r)
    {
        s[k]++;//叶子节点维护的就是这个数出现的次数
        return;
    }
    int mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf(k << 1,l,mid,x);
    }
    if(x > mid)
    {
        mdf(k << 1 | 1,mid + 1,r,x);
    }
    s[k] = s[k << 1] + s[k << 1 | 1];
}

求某数出现的次数

又上代码:

int query1(int k,int l,int r,int x)
{
    if(l == r)
    {
        return s[k];//到了叶子结点就可以直接返回了
    }
    int mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans = query(k << 1,l,mid,x);
    }
    if(x > mid)
    {
        ans = query(k << 1 | 1,mid + 1,r,x);
    }
    return ans;
}

查询区间中数出现的总次数

注意这里 ans+=,不然就覆盖了

int query2(int k,int l,int r,int x,int y)
{
    if(x <= l && r <= y)
    {
        return s[k];//查询完全覆盖当前区间,直接返回
    }
    int mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += query(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += query(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}

查询整个数列第x小的数

  1. 求出当前节点左右子树的元素个数
  2. 如果 x \le lenl,则说明要找的数在左子树中,所以递归到左子树
  3. 反之,说明要找的数在右子树,递归到右子树,注意将 x-lenl

代码如下:

int query3(int k,int l,int r,int x)
{
    if(l == r)
    {
        return l;//注意这里返回的是s的下标而不是s数组
    }
    int mid = (l + r) >> 1,lenl = s[k << 1],ans = 0;//没错这里lenr用不到(
    if(x <= lenl)
    {
        ans = query3(k << 1,l,mid,x);
    }
    else
    {
        ans = query3(k << 1 | 1,mid + 1,r,x - lenl);
    }
    return ans;
}

例题

例1:P1908

题面

思路是显然的,按顺序加入数并每次查询区间 [a_i+1,maxa] 之间的数出现的总次数,不过值域较大所以需要离散化

注意:离散化要开一个数组辅助

Code:
#include <bits/stdc++.h>
using namespace std;
int n,a1[500005],a[500005];
long long s[2000005];
void mdf(int k,int l,int r,int x)
{
    if(l == r)
    {
        s[k]++;
        return;
    }
    int mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf(k << 1,l,mid,x);
    }
    if(x > mid)
    {
        mdf(k << 1 | 1,mid + 1,r,x);
    }
    s[k] = s[k << 1] + s[k << 1 | 1];
}
long long query(int k,int l,int r,int x,int y)
{
    if(x <= l && r <= y)
    {
        return s[k];
    }
    int mid = (l + r) >> 1;
    long long ans = 0;
    if(x <= mid)
    {
        ans += query(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += query(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
int main()
{
    scanf("%d",&n);
    for(int i = 1;i <= n;i++)
    {
        scanf("%d",&a[i]);
        a1[i] = a[i];
    }
    sort(a1 + 1,a1 + n + 1);
    int m = unique(a1 + 1,a1 + n + 1) - a1 - 1;
    for(int i = 1;i <= n;i++)
    {
        a[i] = lower_bound(a1 + 1,a1 + m + 1,a[i]) - a1;
    }
    long long ans = 0;
    for(int i = 1;i <= n;i++)
    {
        mdf(1,1,n,a[i]);
        ans += query(1,1,n,a[i] + 1,n);
    }
    printf("%lld",ans);
    return 0;
}

例2:P1637

题面

由于题目要求找三元组,所以不妨从中间入手

易想到,确定中间的数后,找到左边比它小的数的个数和右边比它大的数的个数相乘即为它的贡献

显然可以用权值线段树实现,分别预处理,第一次遍历 1 ~ n,每次找比当前数小的并加入,第二次遍历 n ~ 1,每次找比当前数大的并加入

此题值域只有 10^5,所以无需离散化

死因:

  1. 不需要建树
  2. 查询和加入操作值域并非 [1,n],而是 [1,maxa][1,10^5]
  3. 不支持区间 [x,y]x>y 情况下的查询,所以找比 a_i 小的数时,应确保 a_i \ge 2
Code:
#include <bits/stdc++.h>
using namespace std;
int n,a[30005],c[100005];
long long x[30005],y[30005],s[400005];
void mdf(int k,int l,int r,int x)
{
    if(l == r)
    {
        s[k]++;
        return;
    }
    int mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf(k << 1,l,mid,x);
    }
    if(x > mid)
    {
        mdf(k << 1 | 1,mid + 1,r,x);
    }
    s[k] = s[k << 1] + s[k << 1 | 1];
}
long long query(int k,int l,int r,int x,int y)
{
    if(x <= l && r <= y)
    {
        return s[k];
    }
    int mid = (l + r) >> 1;
    long long ans = 0;
    if(x <= mid)
    {
        ans += query(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += query(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
int main()
{
    scanf("%d",&n);
    for(int i = 1;i <= n;i++)
    {
        scanf("%d",&a[i]);
        c[a[i]]++;
    }
    for(int i = 1;i <= n;i++)
    {
        mdf(1,1,1e5,a[i]);
        if(a[i] >= 2)
        {
            x[i] = query(1,1,1e5,1,a[i] - 1);
        }
    }
    memset(s,0,sizeof(s));
    for(int i = n;i >= 1;i--)
    {
        mdf(1,1,1e5,a[i]);
        y[i] = query(1,1,1e5,a[i] + 1,1e5);
    }
    long long ans = 0;
    for(int i = 2;i <= n - 1;i++)
    {
        ans += (long long)x[i] * y[i];
    }
    printf("%lld",ans);
    return 0;
}

例3:P6186

题面

权值线段树预处理出原数列 a 每一个数的逆序对数组成的序列 b,所以首先要建一棵树 s1 维护 a每个数出现次数和一段范围内的数出现次数之和

对于两个极其奇怪的操作分别考虑:

(1)交换

分类讨论,操作后 b 数组的变化如下图所示,如果想要维护这个变化,则需要对于 b 再建一棵权值线段树 s2,维护的东西和 s1 一样

这里一定要注意下标的问题,到底对应什么值很容易混

(2)询问

造一组数据进行冒泡排序找规律,发现每经过一轮冒泡排序,所有仍有逆序对的数,它们的逆序对数都会减一,直到 0 就代表排好了

那么有了这个规律,经过 x 轮冒泡排序后,所有初始逆序对数 <x 的直接不用考虑了,因为都变成 0 了,那些 \ge x 的逆序对数每个减去 x,最后剩下的逆序对数之和就是答案

转化成权值线段树表示为:查询 [x,n] 区间内的所有逆序对数之和,减去大小在这个区间内的逆序对个数乘 x 的值

写代码的时候要注意,每棵树都有一个mdf函数和两个query函数

Code:
#include <bits/stdc++.h>
using namespace std;
long long n,q,a[1000005],b[1000005];
struct tree
{
    long long cnt,sum;
}s1[4000005],s2[4000005];
void mdf1(long long k,long long l,long long r,long long x,long long v)
{
    if(l == r)
    {
        s1[k].cnt += v;
        s1[k].sum += v * x;
        return;
    }
    long long mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf1(k << 1,l,mid,x,v);
    }
    else
    {
        mdf1(k << 1 | 1,mid + 1,r,x,v);
    }
    s1[k].cnt = (s1[k << 1].cnt + s1[k << 1 | 1].cnt);
    s1[k].sum = (s1[k << 1].sum + s1[k << 1 | 1].sum);
}
void mdf2(long long k,long long l,long long r,long long x,long long v)
{
    if(l == r)
    {
        s2[k].cnt += v;
        s2[k].sum += v * x;
        return;
    }
    long long mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf2(k << 1,l,mid,x,v);
    }
    else
    {
        mdf2(k << 1 | 1,mid + 1,r,x,v);
    }
    s2[k].cnt = (s2[k << 1].cnt + s2[k << 1 | 1].cnt);
    s2[k].sum = (s2[k << 1].sum + s2[k << 1 | 1].sum);
}
long long queryc1(long long k,long long l,long long r,long long x,long long y)
{
    if(x <= l && y >= r)
    {
        return s1[k].cnt;
    }
    long long mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += queryc1(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += queryc1(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
long long queryc2(long long k,long long l,long long r,long long x,long long y)
{
    if(x <= l && y >= r)
    {
        return s2[k].cnt;
    }
    long long mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += queryc2(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += queryc2(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
long long querys1(long long k,long long l,long long r,long long x,long long y)
{
    if(x <= l && y >= r)
    {
        return s1[k].sum;
    }
    long long mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += querys1(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += querys1(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
long long querys2(long long k,long long l,long long r,long long x,long long y)
{
    if(x <= l && y >= r)
    {
        return s2[k].sum;
    }
    long long mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += querys2(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += querys2(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
int main()
{
    scanf("%lld%lld",&n,&q);
    for(long long i = 1;i <= n;i++)
    {
        scanf("%lld",&a[i]);
    }
    for(long long i = 1;i <= n;i++)
    {
        b[i] = queryc1(1,0,n,a[i] + 1,n);
        mdf1(1,0,n,a[i],1);
        mdf2(1,0,n,b[i],1);
    }
    while(q--)
    {
        long long f,x;
        scanf("%lld%lld",&f,&x);
        if(f == 1)
        {
            if(a[x] < a[x + 1])
            {
                swap(a[x],a[x + 1]);
                swap(b[x],b[x + 1]);
                mdf2(1,0,n,b[x + 1],-1);
                b[x + 1]++;
                mdf2(1,0,n,b[x + 1],1);
            }
            else
            {
                swap(a[x],a[x + 1]);
                swap(b[x],b[x + 1]);
                mdf2(1,0,n,b[x],-1);
                b[x]--;
                mdf2(1,0,n,b[x],1);
            }
        }
        else
        {
            x = min(x,n);
            printf("%lld\n",querys2(1,0,n,x,n) - queryc2(1,0,n,x,n) * x);
        }
    }
    return 0;
}

例4:P3369

题面

没想到平衡树可以用权值线段树实现喵,代码不到 100 行!

  1. 插入 x:mdf函数位置 x 加一
  2. 删除 x:mdf函数位置 x 减一
  3. x$ 的排名:query1函数查询 $[0,x-1]$ 区间内的数出现次数之和 $+1
  4. 排第 x 的数:query2函数,模版见上,每次判断实在左子树还是右子树

因为 |x| \le 10^7不妨将数列里的数都 +10^7 保证非负,注意查询时要再减掉

Code:
#include <bits/stdc++.h>
using namespace std;
int n = 2e7,q,s[80000005];
void mdf(int k,int l,int r,int x,int v)
{
    if(l == r)
    {
        s[k] += v;
        return;
    }
    int mid = (l + r) >> 1;
    if(x <= mid)
    {
        mdf(k << 1,l,mid,x,v);
    }
    else
    {
        mdf(k << 1 | 1,mid + 1,r,x,v);
    }
    s[k] = s[k << 1] + s[k << 1 | 1];
}
int query1(int k,int l,int r,int x,int y)
{
    if(x <= l && r <= y)
    {
        return s[k];
    }
    int mid = (l + r) >> 1,ans = 0;
    if(x <= mid)
    {
        ans += query1(k << 1,l,mid,x,y);
    }
    if(y > mid)
    {
        ans += query1(k << 1 | 1,mid + 1,r,x,y);
    }
    return ans;
}
int query2(int k,int l,int r,int x)
{
    if(l == r)
    {
        return l;
    }
    int mid = (l + r) >> 1,lenl = s[k << 1],ans = 0;
    if(x <= lenl)
    {
        ans = query2(k << 1,l,mid,x);
    }
    else
    {
        ans = query2(k << 1 | 1,mid + 1,r,x - lenl);
    }
    return ans;
}
int main()
{
    scanf("%d",&q);
    int m = 0;
    while(q--)
    {
        int f,x;
        scanf("%d%d",&f,&x);
        if(f == 1)
        {
            x += 10000000;
            mdf(1,0,n,x,1);
            m++;
        }
        else if(f == 2)
        {
            x += 10000000;
            mdf(1,0,n,x,-1);
            m--;
        }
        else if(f == 3)
        {
            x += 10000000;
            printf("%d\n",query1(1,0,n,0,x - 1) + 1);
        }
        else if(f == 4)
        {
            printf("%d\n",query2(1,0,n,x) - 10000000);
        }
        else if(f == 5)
        {
            x += 10000000;
            int y = query1(1,0,n,0,x - 1);
            printf("%d\n",query2(1,0,n,y) - 10000000);
        }
        else
        {
            x += 10000000;
            int y = m - query1(1,0,n,x + 1,n) + 1;
            printf("%d\n",query2(1,0,n,y) - 10000000);
        }
    }
    return 0;
}

OK呀第一种进阶线段树也是学会了好吧喵耶!!!

─=≡Σ(((つ>·ω·)つ