题解 P6278 【[USACO20OPEN]Haircut G】

· · 题解

Haircut

题目传送

前言

这题知道算法也知道是维护逆序对的数量,但是自己的思路没能实现出来

然后去看题解,结果没看懂代码为什么这么写(大概是我思维跟不上这么跳跃)

最后思考了半个小时才终于搞明白,下面会较详细的写明原理

(相当于给其他题解写个注释说明吧)

算法

思路

“对于j= 012 ··· n-1”输出头发的不良度

嗯,看到这种题目一般会想到 逆序思想 :例如 顺着关闭 = 倒着开放

然后这个不良度其实就是逆序对的数量,而维护逆序对一般首选树状数组

初始长度:5 2 3 3 0
当 j = 0:0 0 0 0 0  逆序对:0
当 j = 1:1 1 1 1 0  逆序对:4
当 j = 2:2 2 2 2 0  逆序对:4
当 j = 3: 3 2 3 3 0  逆序对:5
当 j = 4: 4 2 3 3 0  逆序对:7

我们发现,i 的初始长度就是第 i 缕头发最多能长多长

再来看,题目中给定 n 是所有头发的最长长度,那么 n - a[i] 就能够表示第 i 缕头发和最长长度的差距

求出这个差距有什么用?

你想啊,第 i 缕头发长到 a[i] 就不能长了,那它就可能会成为其他能够长的头发的逆序对!

然后呢?我们用 x 表示 n - a[i] ,接着就去查询 x 的前缀,然后把结果累加到 sum[a[i]] 中 (sum[a[i]] 表示这个位置的逆序对)

累加完之后再把当前这个位置加入到树状数组中

上面说到,维护逆序对我们一般使用树状数组,而树状数组和线段树是不能维护 0 的,所以我们需要将所有值都加一,当然,题目中 j 的范围也就变成 1 ~ n 了,以及 x 就等于 n-a[i]+1

代码

#include <bits/stdc++.h>
#define lo long long 
using namespace std;
lo n,x,ans,a[520010],c[520010],sum[520010];

inline void add(lo i,lo k) {
    for(;i<=n+1;i+=i&(-i)) c[i]+=k;
}

inline lo ask(lo i) {
    register lo res=0;
    for(;i;i-=i&(-i)) res+=c[i];
    return res;
}

int main() {
    scanf("%lld",&n);
    for(register lo i=1;i<=n;i++) {
        scanf("%lld",&a[i]);
        a[i]++;
    }
    for(register lo i=1;i<=n;i++) {
        x=n-a[i]+1;
        sum[a[i]]+=ask(x);
        add(x+1,1);  //因为x依旧可能为0,所以还是统一加一
    }
    puts("0");
    for(register lo i=1;i<n;i++) {
        ans+=sum[i];
        printf("%lld\n",ans);
    }
    return 0;
}
#include <bits/stdc++.h>
#define lo long long
using namespace std;
lo n,x,ans,a[520010],sum[520010],num[520010];

inline void build(lo l,lo r,lo bh) {
    if(l==r) {
        sum[bh]=0;
        return ;
    }
    lo mid=(l+r)>>1;
    build(l,mid,bh<<1);
    build(mid+1,r,bh<<1|1);
    sum[bh]=sum[bh<<1]+sum[bh<<1|1];
}

inline void change(lo L,lo l,lo r,lo bh) {
    if(l==r) {
        sum[bh]++;
        return ;
    }
    lo mid=(l+r)>>1;
    if(L<=mid) change(L,l,mid,bh<<1);
    else change(L,mid+1,r,bh<<1|1);
    sum[bh]=sum[bh<<1]+sum[bh<<1|1];
}

inline lo ask(lo L,lo R,lo l,lo r,lo bh) {
    if(L<=l&&r<=R) {
        return sum[bh];
    }
    lo mid=(l+r)>>1,res=0;
    if(L<=mid) res+=ask(L,R,l,mid,bh<<1);
    if(R>mid) res+=ask(L,R,mid+1,r,bh<<1|1);
    return res;
}

int main() {
    scanf("%lld",&n);
    for(register lo i=1;i<=n;i++) {
        scanf("%lld",&a[i]);
        a[i]++;
    }
    build(1,n,1);
    for(register lo i=1;i<=n;i++) {
        x=n-a[i]+1;
        if(x>0) num[a[i]]+=ask(1,x,1,n,1);
        change(x+1,1,n,1);
    }
    puts("0");
    for(register lo i=1;i<n;i++) {
        ans+=num[i];
        printf("%lld\n",ans);
    }
    return 0;
}

希望讲清楚了点