CF380E Sereja and Dividing 题解

· · 题解

这是一个只需要写区间乘&区间求和线段树的做法。

先考虑如何对一个序列求答案。显然是把序列升序排序后依次操作。稍微推一下贡献,发现序列中的数从大到小贡献依次为:

\frac{1}{2},\frac{1}{4},\frac{1}{8},\frac{1}{16}\cdots

然后从左往右扫描区间右端点,每次插入一个数,对于一个左端点,假设原来区间中所有数从小到大依次为 a_1,a_2,\cdots,a_t,a_{t+1},\cdots,a_n,假设 b_i 插入到 a_ta_{t+1} 之间,那么 a_1\sim a_t 的贡献系数会乘以 \frac{1}{2},后面的不变,b_i 的贡献系数为 \frac{1}{2^{t+1}}

尝试对所有左端点快速维护这个变换。肯定是放到值域上考虑,插入 b_i 形如对值域上一段前缀乘以 \frac{1}{2}。问题是如何求出当前 b_i 的贡献系数。发现它等于前缀所有数的贡献系数之和(把初始的 0 也算上)。因此维护区间和以及区间贡献系数之和即可。

#include<iostream>
#include<iomanip>
using namespace std;
const int N = 3e5 + 10;
const int M = 1e5 + 10;

int n, m = 1e5;
int a[N];

namespace SegT {
    struct Node {
        double sum1, sum2;
        double tag;
        inline Node(double _sum1 = 0.0, double _sum2 = 0.0, double _tag = 1.0) :
            sum1(_sum1), sum2(_sum2), tag(_tag) {} 
    } tr[4 * M];
    inline int lc(int x) { return x << 1; }
    inline int rc(int x) { return x << 1 | 1; }
    inline void push_up(int p) {
        Node &l = tr[lc(p)], &r = tr[rc(p)];
        tr[p].sum1 = l.sum1 + r.sum1;
        tr[p].sum2 = l.sum2 + r.sum2;
    }
    inline void spread(int p, double tg) { tr[p].tag *= tg; tr[p].sum1 /= tg, tr[p].sum2 /= tg; }
    inline void push_down(int p) {
        if(tr[p].tag != 1.0) spread(lc(p), tr[p].tag), spread(rc(p), tr[p].tag), tr[p].tag = 1.0;
    }
    void mul(int p, int l, int r, int ql, int qr, double v) {
        if(ql <= l && r <= qr) return spread(p, v);
        int mid = (l + r) >> 1; push_down(p);
        if(ql <= mid) mul(lc(p), l, mid, ql, qr, v);
        if(mid < qr) mul(rc(p), mid + 1, r, ql, qr, v);
        push_up(p);
    }
    void insert(int p, int l, int r, int q, double v1, double v2) {
        if(l == r) return tr[p].sum1 += v1, tr[p].sum2 += v2, void();
        int mid = (l + r) >> 1; push_down(p);
        if(q <= mid) insert(lc(p), l, mid, q, v1, v2);
        else insert(rc(p), mid + 1, r, q, v1, v2);
        push_up(p);
    }
    void query(int p, int l, int r, int ql, int qr, double &v1, double &v2) {
        if(ql <= l && r <= qr) return v1 += tr[p].sum1, v2 += tr[p].sum2, void();
        int mid = (l + r) >> 1; push_down(p);
        if(ql <= mid) query(lc(p), l, mid, ql, qr, v1, v2);
        if(mid < qr) query(rc(p), mid + 1, r, ql, qr, v1, v2);
    }
}

double ans;

int main() {

    cin >> n;
    for(int i = 1; i <= n; i++) cin >> a[i];

    for(int i = 1; i <= n; i++) {
        double v1 = 0.0, v2 = 0.0;
        SegT::mul(1, 1, m, 0, a[i], 2);
        SegT::query(1, 1, m, 0, a[i], v1, v2);
        SegT::insert(1, 1, m, a[i], v1 + 0.5, a[i] * (v1 + 0.5));
        SegT::insert(1, 1, m, 0, 0.5, 0.0);
        ans += SegT::tr[1].sum2;
    }

    cout << fixed << setprecision(8) << ans / n / n << '\n';

    return 0;
}