树状数组怎么做三维偏序?

· · 题解

首先你需要会二维分块。

二维分块就是你首先将每一维坐标离散化为排列,此时 x, y 是唯一对应的,这样你可以暴力统计一个矩形散块内的答案,复杂度即为其最短边的长度。

我们考虑使用二维树状数组处理整块,维护的二维平面上每个位置的值是对应整块内所有数之和,设块长为 B,空间复杂度即为 O(n + (n/B)^2)

然后你修改的时候在树状数组上改对应整块,同时改单点,查询的时候在树状数组上查对应前缀整块,同时暴力统计散块,这样时间复杂度是 O(n \log^2(n/B) + nB) 的。

如果取 B = \log^2 n,那么时间复杂度为 O(n \log^2 n),空间复杂度为 O(n^2/\log^4 n)

https://www.luogu.com.cn/record/265646886

#include <iostream>
#include <utility>
#include <vector>
using namespace std;
const int N = 1e5 + 5, K = 5e5 + 5, B = 200, S = N / B + 5;
int n, k, a[N], b[N], c[N], la[K], ra[K], lb[K], rb[K], ab[N], ba[N], bel[N], va[N], vb[N], t[S][S],
    ans[N];
vector<pair<int, int>> vu[K], vq[K];
void update_(int x, int y) {
    for (int i = x; i <= bel[n]; i += i & -i)
        for (int j = y; j <= bel[n]; j += j & -j) ++t[i][j];
}
int query_(int x, int y) {
    int ret = 0;
    for (int i = x; i >= 1; i -= i & -i)
        for (int j = y; j >= 1; j -= j & -j) ret += t[i][j];
    return ret;
}
void update(int x, int y) { update_(bel[x], bel[y]), ++va[x], ++vb[y]; }
int query(int x, int y) {
    int ret = query_(bel[x] - 1, bel[y] - 1);
    for (int i = B * (bel[x] - 1) + 1; i <= x; i++)
        if (ab[i] <= y) ret += va[i];
    for (int i = B * (bel[y] - 1) + 1; i <= y; i++)
        if (ba[i] <= B * (bel[x] - 1)) ret += vb[i];
    return ret;
}
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> a[i] >> b[i] >> c[i], ++ra[a[i]], ++rb[b[i]];
    for (int i = 1; i <= k; i++)
        la[i] = ra[i - 1] + 1, ra[i] += ra[i - 1], lb[i] = rb[i - 1] + 1, rb[i] += rb[i - 1];
    for (int i = 1; i <= n; i++) {
        int x = la[a[i]]++, y = lb[b[i]]++;
        ab[x] = y, ba[y] = x;
        vu[c[i]].emplace_back(x, y);
        vq[c[i]].emplace_back(ra[a[i]], rb[b[i]]);
    }
    for (int i = 1; i <= n; i++) bel[i] = (i - 1) / B + 1;
    for (int i = 1; i <= k; i++) {
        for (auto [x, y] : vu[i]) update(x, y);
        for (auto [x, y] : vq[i]) ++ans[query(x, y) - 1];
    }
    for (int i = 0; i < n; i++) cout << ans[i] << '\n';
}