权值线段树
基本的线段树节点对应的区间为原数列的下标,值是这段区间的数的
对于一个数列,普通线段树可以维护某个子数组中数的和,而权值线段树可以维护某个区间内的数出现次数,感性理解一下:
维护的值域有时可能过大,所以需要先离散化
建树
几乎和普通线段树一样,上代码:
void build(int k,int l,int r)
{
if(l == r)
{
s[k] = c[l];//c[l]是l出现的次数
return;
}
int mid = (l + r) >> 1;
build(k << 1,l,mid);
build(k << 1 | 1,mid + 1,r);
s[k] = s[k << 1] + s[k << 1 | 1];
}
添加单个数
上代码:
void mdf(int k,int l,int r,int x)
{
if(l == r)
{
s[k]++;//叶子节点维护的就是这个数出现的次数
return;
}
int mid = (l + r) >> 1;
if(x <= mid)
{
mdf(k << 1,l,mid,x);
}
if(x > mid)
{
mdf(k << 1 | 1,mid + 1,r,x);
}
s[k] = s[k << 1] + s[k << 1 | 1];
}
求某数出现的次数
又上代码:
int query1(int k,int l,int r,int x)
{
if(l == r)
{
return s[k];//到了叶子结点就可以直接返回了
}
int mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans = query(k << 1,l,mid,x);
}
if(x > mid)
{
ans = query(k << 1 | 1,mid + 1,r,x);
}
return ans;
}
查询区间中数出现的总次数
注意这里
int query2(int k,int l,int r,int x,int y)
{
if(x <= l && r <= y)
{
return s[k];//查询完全覆盖当前区间,直接返回
}
int mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += query(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += query(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
查询整个数列第x小的数
- 求出当前节点左右子树的元素个数
- 如果
x \le lenl ,则说明要找的数在左子树中,所以递归到左子树 - 反之,说明要找的数在右子树,递归到右子树,注意将
x-lenl
代码如下:
int query3(int k,int l,int r,int x)
{
if(l == r)
{
return l;//注意这里返回的是s的下标而不是s数组
}
int mid = (l + r) >> 1,lenl = s[k << 1],ans = 0;//没错这里lenr用不到(
if(x <= lenl)
{
ans = query3(k << 1,l,mid,x);
}
else
{
ans = query3(k << 1 | 1,mid + 1,r,x - lenl);
}
return ans;
}
例题
例1:P1908
题面
思路是显然的,按顺序加入数并每次查询区间
注意:离散化要开一个数组辅助
#include <bits/stdc++.h>
using namespace std;
int n,a1[500005],a[500005];
long long s[2000005];
void mdf(int k,int l,int r,int x)
{
if(l == r)
{
s[k]++;
return;
}
int mid = (l + r) >> 1;
if(x <= mid)
{
mdf(k << 1,l,mid,x);
}
if(x > mid)
{
mdf(k << 1 | 1,mid + 1,r,x);
}
s[k] = s[k << 1] + s[k << 1 | 1];
}
long long query(int k,int l,int r,int x,int y)
{
if(x <= l && r <= y)
{
return s[k];
}
int mid = (l + r) >> 1;
long long ans = 0;
if(x <= mid)
{
ans += query(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += query(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
int main()
{
scanf("%d",&n);
for(int i = 1;i <= n;i++)
{
scanf("%d",&a[i]);
a1[i] = a[i];
}
sort(a1 + 1,a1 + n + 1);
int m = unique(a1 + 1,a1 + n + 1) - a1 - 1;
for(int i = 1;i <= n;i++)
{
a[i] = lower_bound(a1 + 1,a1 + m + 1,a[i]) - a1;
}
long long ans = 0;
for(int i = 1;i <= n;i++)
{
mdf(1,1,n,a[i]);
ans += query(1,1,n,a[i] + 1,n);
}
printf("%lld",ans);
return 0;
}
例2:P1637
题面
由于题目要求找三元组,所以不妨从中间入手
易想到,确定中间的数后,找到左边比它小的数的个数和右边比它大的数的个数,相乘即为它的贡献
显然可以用权值线段树实现,分别预处理,第一次遍历
此题值域只有
死因:
- 不需要建树
- 查询和加入操作值域并非
[1,n] ,而是[1,maxa] ([1,10^5] ) - 不支持区间
[x,y] 在x>y 情况下的查询,所以找比a_i 小的数时,应确保a_i \ge 2
#include <bits/stdc++.h>
using namespace std;
int n,a[30005],c[100005];
long long x[30005],y[30005],s[400005];
void mdf(int k,int l,int r,int x)
{
if(l == r)
{
s[k]++;
return;
}
int mid = (l + r) >> 1;
if(x <= mid)
{
mdf(k << 1,l,mid,x);
}
if(x > mid)
{
mdf(k << 1 | 1,mid + 1,r,x);
}
s[k] = s[k << 1] + s[k << 1 | 1];
}
long long query(int k,int l,int r,int x,int y)
{
if(x <= l && r <= y)
{
return s[k];
}
int mid = (l + r) >> 1;
long long ans = 0;
if(x <= mid)
{
ans += query(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += query(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
int main()
{
scanf("%d",&n);
for(int i = 1;i <= n;i++)
{
scanf("%d",&a[i]);
c[a[i]]++;
}
for(int i = 1;i <= n;i++)
{
mdf(1,1,1e5,a[i]);
if(a[i] >= 2)
{
x[i] = query(1,1,1e5,1,a[i] - 1);
}
}
memset(s,0,sizeof(s));
for(int i = n;i >= 1;i--)
{
mdf(1,1,1e5,a[i]);
y[i] = query(1,1,1e5,a[i] + 1,1e5);
}
long long ans = 0;
for(int i = 2;i <= n - 1;i++)
{
ans += (long long)x[i] * y[i];
}
printf("%lld",ans);
return 0;
}
例3:P6186
题面
用权值线段树预处理出原数列
对于两个极其奇怪的操作分别考虑:
(1)交换
分类讨论,操作后
这里一定要注意下标的问题,到底对应什么值很容易混
(2)询问
造一组数据进行冒泡排序找规律,发现每经过一轮冒泡排序,所有仍有逆序对的数,它们的逆序对数都会减一,直到
那么有了这个规律,经过
转化成权值线段树表示为:查询
写代码的时候要注意,每棵树都有一个mdf函数和两个query函数
#include <bits/stdc++.h>
using namespace std;
long long n,q,a[1000005],b[1000005];
struct tree
{
long long cnt,sum;
}s1[4000005],s2[4000005];
void mdf1(long long k,long long l,long long r,long long x,long long v)
{
if(l == r)
{
s1[k].cnt += v;
s1[k].sum += v * x;
return;
}
long long mid = (l + r) >> 1;
if(x <= mid)
{
mdf1(k << 1,l,mid,x,v);
}
else
{
mdf1(k << 1 | 1,mid + 1,r,x,v);
}
s1[k].cnt = (s1[k << 1].cnt + s1[k << 1 | 1].cnt);
s1[k].sum = (s1[k << 1].sum + s1[k << 1 | 1].sum);
}
void mdf2(long long k,long long l,long long r,long long x,long long v)
{
if(l == r)
{
s2[k].cnt += v;
s2[k].sum += v * x;
return;
}
long long mid = (l + r) >> 1;
if(x <= mid)
{
mdf2(k << 1,l,mid,x,v);
}
else
{
mdf2(k << 1 | 1,mid + 1,r,x,v);
}
s2[k].cnt = (s2[k << 1].cnt + s2[k << 1 | 1].cnt);
s2[k].sum = (s2[k << 1].sum + s2[k << 1 | 1].sum);
}
long long queryc1(long long k,long long l,long long r,long long x,long long y)
{
if(x <= l && y >= r)
{
return s1[k].cnt;
}
long long mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += queryc1(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += queryc1(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
long long queryc2(long long k,long long l,long long r,long long x,long long y)
{
if(x <= l && y >= r)
{
return s2[k].cnt;
}
long long mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += queryc2(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += queryc2(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
long long querys1(long long k,long long l,long long r,long long x,long long y)
{
if(x <= l && y >= r)
{
return s1[k].sum;
}
long long mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += querys1(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += querys1(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
long long querys2(long long k,long long l,long long r,long long x,long long y)
{
if(x <= l && y >= r)
{
return s2[k].sum;
}
long long mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += querys2(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += querys2(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
int main()
{
scanf("%lld%lld",&n,&q);
for(long long i = 1;i <= n;i++)
{
scanf("%lld",&a[i]);
}
for(long long i = 1;i <= n;i++)
{
b[i] = queryc1(1,0,n,a[i] + 1,n);
mdf1(1,0,n,a[i],1);
mdf2(1,0,n,b[i],1);
}
while(q--)
{
long long f,x;
scanf("%lld%lld",&f,&x);
if(f == 1)
{
if(a[x] < a[x + 1])
{
swap(a[x],a[x + 1]);
swap(b[x],b[x + 1]);
mdf2(1,0,n,b[x + 1],-1);
b[x + 1]++;
mdf2(1,0,n,b[x + 1],1);
}
else
{
swap(a[x],a[x + 1]);
swap(b[x],b[x + 1]);
mdf2(1,0,n,b[x],-1);
b[x]--;
mdf2(1,0,n,b[x],1);
}
}
else
{
x = min(x,n);
printf("%lld\n",querys2(1,0,n,x,n) - queryc2(1,0,n,x,n) * x);
}
}
return 0;
}
例4:P3369
题面
没想到平衡树可以用权值线段树实现喵,代码不到
- 插入
x :mdf函数位置x 加一 - 删除
x :mdf函数位置x 减一 -
x$ 的排名:query1函数查询 $[0,x-1]$ 区间内的数出现次数之和 $+1 - 排第
x 的数:query2函数,模版见上,每次判断实在左子树还是右子树 -
-
因为
#include <bits/stdc++.h>
using namespace std;
int n = 2e7,q,s[80000005];
void mdf(int k,int l,int r,int x,int v)
{
if(l == r)
{
s[k] += v;
return;
}
int mid = (l + r) >> 1;
if(x <= mid)
{
mdf(k << 1,l,mid,x,v);
}
else
{
mdf(k << 1 | 1,mid + 1,r,x,v);
}
s[k] = s[k << 1] + s[k << 1 | 1];
}
int query1(int k,int l,int r,int x,int y)
{
if(x <= l && r <= y)
{
return s[k];
}
int mid = (l + r) >> 1,ans = 0;
if(x <= mid)
{
ans += query1(k << 1,l,mid,x,y);
}
if(y > mid)
{
ans += query1(k << 1 | 1,mid + 1,r,x,y);
}
return ans;
}
int query2(int k,int l,int r,int x)
{
if(l == r)
{
return l;
}
int mid = (l + r) >> 1,lenl = s[k << 1],ans = 0;
if(x <= lenl)
{
ans = query2(k << 1,l,mid,x);
}
else
{
ans = query2(k << 1 | 1,mid + 1,r,x - lenl);
}
return ans;
}
int main()
{
scanf("%d",&q);
int m = 0;
while(q--)
{
int f,x;
scanf("%d%d",&f,&x);
if(f == 1)
{
x += 10000000;
mdf(1,0,n,x,1);
m++;
}
else if(f == 2)
{
x += 10000000;
mdf(1,0,n,x,-1);
m--;
}
else if(f == 3)
{
x += 10000000;
printf("%d\n",query1(1,0,n,0,x - 1) + 1);
}
else if(f == 4)
{
printf("%d\n",query2(1,0,n,x) - 10000000);
}
else if(f == 5)
{
x += 10000000;
int y = query1(1,0,n,0,x - 1);
printf("%d\n",query2(1,0,n,y) - 10000000);
}
else
{
x += 10000000;
int y = m - query1(1,0,n,x + 1,n) + 1;
printf("%d\n",query2(1,0,n,y) - 10000000);
}
}
return 0;
}
OK呀第一种进阶线段树也是学会了好吧喵耶!!!
─=≡Σ(((つ>·ω·)つ