题解:P10884 [COCI 2017/2018 #2] San

· · 题解

折半搜索好题。

看到 n \le 40,想到这有可能需要使用搜索解决,直接 dfs 的复杂度是 O(2^n),无法接受,于是想到了折半搜索。

我们将其平分成两部分,先对前半部分进行 dfs,算出数组 lstsum,分别表示各个序列的结尾的数和该序列选出的金币总和,对于金币总和已经大于 k 的序列,我们先计入答案。

对于后半段,同样进行 dfs,存储下来序列开头 st,结尾 lt 和金币总和 s。那么,对于每一个后半段的序列,其对答案的贡献就是 lst_i \le stsum_i \ge k - s 的前半段的线段的数量。这是一个二维偏序问题,直接暴力找肯定是不行的,我们可以利用 BIT 来完成这个问题,首先,将 lstsum 存入结构体并将其按 lst 从小到大排序,然后,对于每一个后半段的序列,二分查找出最后一个 lst_i \le st 的位置 l,并将其与 k - s 一同存入到一个询问结构体中,将结构体按 l 排序,然后枚举 i,不断加入 sum_i,同时处理 l=i 的询问,统计答案即可,由于值比较大,所以还要离散化一下。

最终复杂度就是 O(2^{\frac{n}{2}}\log 2^{\frac{n}{2}})),最慢点可以 200ms 内跑过。

总体代码比较长,模拟赛时写累死我了

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i, s, t) for (int i = s; i <= t; i++)
#define pre(i, s, t) for (int i = s; i >= t; i--)
#define pb push_back
#define all(x) begin(x), end(x)
const int N = 2e6 + 2;
int n, k, t, m, ans, h[N], g[N], idx, sa[N], la[N], tr[N << 1], V;
vector<int> al;
struct node { int sum, lst; } a[N];
struct Q { int p, v; } q[N]; // 找(1, p)间大于v的数的个数 
bool cmp(node &x, node &y) { return x.lst < y.lst; }
bool cmp2(Q &x, Q &y) { return x.p < y.p; }
void dfs1(int u, int s, int lt) {
    if (u == t + 1) {
        if (s >= k) ans++; // 前面就大于 k 的也要算上
        return a[++idx] = {s, lt}, void();
    }
    if (lt <= h[u]) dfs1(u + 1, s + g[u], h[u]); // 需要大于上一个值才可选
    dfs1(u + 1, s, lt);
}
void dfs2(int u, int s, int st, int lt, int mk) { // mk 用于标记之前是否有选过数,其实就是用来算 st 的
    if (u == n + 1) {
        if (!mk) return;
        int l = upper_bound(la + 1, la + idx + 1, st) - la - 1;
        return q[++m] = {l, k - s}, al.pb(k - s), void();
    }
    if (lt <= h[u]) dfs2(u + 1, s + g[u], !mk ? h[u] : st, h[u], 1);
    dfs2(u + 1, s, st, lt, mk);
}
void add(int p, int d) { for(; p <= V; p += p & -p) tr[p] += d; }
int qry(int p) {
    int res = 0;
    for (; p; p -= p & -p) res += tr[p];
    return res;
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> k, t = n / 2; // 取一半
    rep(i, 1, n) cin >> h[i] >> g[i];
    dfs1(1, 0, 0), sort(a + 1, a + idx + 1, cmp);
    rep(i, 1, idx) sa[i] = a[i].sum, la[i] = a[i].lst, al.pb(sa[i]);
    // sa la 与 sum lst 是等价的,后面为了用 STL 所以拆开了 al用于离散
    dfs2(t + 1, 0, 0, 0, 0), sort(all(al));
    // all 是我的宏定义,可以在上面 #define 看意义,还是挺好用的qwq
    al.erase(unique(all(al)), end(al)), V = al.size();
    int j = 1;
    sort(q + 1, q + m + 1, cmp2);
    rep(i, 1, idx) {
        add(lower_bound(all(al), sa[i]) - begin(al) + 1, 1); // 要离散一下
        while (j <= m && q[j].p == i) 
            ans += qry(V) - qry(lower_bound(all(al), q[j].v) - begin(al)), j++;
    }
    cout << ans;
    return 0;
}