cdq^k:k 维偏序

· · 算法·理论

这是一份我原创的 k 维偏序模板,实现了一个名为 kDPO 的模板结构体,代码如下:

代码实现

template <typename Tp, int K>
struct kDPO {
    const vector<array<Tp, K>> &arr;

    kDPO(const vector<array<Tp, K>> &arr_) : arr(arr_) {

    }

    vector<int> a, cnt, ans, hero;
    array<vector<int>, K - 2> b;
    void cdq(int l, int r) {
        if (l == r) {
            return ;
        }

        int mid = (l + r) >> 1;
        cdq(l, mid);
        cdq(mid + 1, r);

        int i = l, j = mid + 1;
        if constexpr (K == 2) {
            int sum = 0;
            hero.clear();
            while (i <= mid || j <= r) {
                if (j > r || (i <= mid && arr[a[i]][1] <= arr[a[j]][1])) {
                    ++sum;
                    hero.emplace_back(a[i]);
                    ++i;
                } else {
                    ans[a[j]] += sum;
                    hero.emplace_back(a[j]);
                    ++j;
                }
            }
            memcpy(a.data() + l, hero.data(), sizeof(int) * (r - l + 1));
        } else {
            b[0].clear();
            while (i <= mid || j <= r) {
                if (j > r || (i <= mid && arr[a[i]][1] <= arr[a[j]][1])) {
                    b[0].push_back(a[i]);
                    ++i;
                } else {
                    b[0].push_back(a[j] | 1 << 31);
                    ++j;
                }
            }
            for (int i = 0; i < r - l + 1; ++i) {
                a[i + l] = b[0][i] & 2147483647;
            }

            cdq(0, r - l, 0);
        }
    }
    void cdq(int l, int r, int d) {
        if (l == r) {
            return ;
        }

        int mid = (l + r) >> 1;
        cdq(l, mid, d);
        cdq(mid + 1, r, d);

        int i = l, j = mid + 1;
        hero.clear();
        if (d == K - 3) {
            int sum = 0;
            while (i <= mid || j <= r) {
                if (j > r || (i <= mid && arr[b[d][i] & 2147483647][K - 1] <= arr[b[d][j] & 2147483647][K - 1])) {
                    sum += b[d][i] >> 31 ? 0 : cnt[b[d][i] & 2147483647];
                    hero.emplace_back(b[d][i]);
                    ++i;
                } else {
                    ans[b[d][j] & 2147483647] += b[d][j] >> 31 ? sum : 0;
                    hero.emplace_back(b[d][j]);
                    ++j;
                }
            }
            memcpy(b[d].data() + l, hero.data(), sizeof(int) * (r - l + 1));
            return ;
        }
        b[d + 1].clear();
        while (i <= mid || j <= r) {
            if (j > r || (i <= mid && arr[b[d][i] & 2147483647][d + 2] <= arr[b[d][j] & 2147483647][d + 2])) {
                if (!(b[d][i] >> 31)) {
                    b[d + 1].push_back(b[d][i]);
                }
                hero.emplace_back(b[d][i]);
                ++i;
            } else {
                if (b[d][j] >> 31) {
                    b[d + 1].push_back(b[d][j]);
                }
                hero.emplace_back(b[d][j]);
                ++j;
            }
        }
        memcpy(b[d].data() + l, hero.data(), sizeof(int) * (r - l + 1));

        if (!b[d + 1].empty()) {
            cdq(0, (int)b[d + 1].size() - 1, d + 1);
        }
    }

    vector<int> operator () () {
        int n = (int)arr.size();

        vector<int> ord(n);
        iota(ord.begin(), ord.end(), 0);
        sort(ord.begin(), ord.end(), [&](int x, int y) {
            return arr[x] < arr[y];
        });

        vector<pair<int, int>> equa;
        equa.reserve(n);
        cnt.resize(n);
        for (int i = 0; i < n; ) {
            int j = i + 1;
            for ( ; j < n && arr[ord[j]] == arr[ord[i]]; ++j) {
                equa.emplace_back(ord[j], ord[i]);
            }

            a.emplace_back(ord[i]);
            cnt[ord[i]] = j - i;

            i = j;
        }
        ans.resize(n);
        hero.reserve(n);
        for (vector<int> &i : b) {
            i.reserve(n);
        }

        cdq(0, (int)a.size() - 1);

        for (int i = 0; i < n; ++i) {
            ans[i] += cnt[i] - 1;
        }
        for (pair<int, int> i : equa) {
            ans[i.first] = ans[i.second];
        }
        return ans;
    }
};
template <typename Tp>
struct kDPO<Tp, 1> {
    const vector<array<Tp, 1>> &arr;

    kDPO(const vector<array<Tp, 1>> &arr_) : arr(arr_) {

    }

    vector<int> operator () () const {
        int n = (int)arr.size();

        vector<int> ord(n);
        iota(ord.begin(), ord.end(), 0);
        sort(ord.begin(), ord.end(), [&](int x, int y) {
            return arr[x] < arr[y];
        });

        vector<int> ans(n);
        for (int i = 0; i < n; ) {
            int j = i + 1;
            for ( ; j < n && arr[ord[j]] == arr[ord[i]]; ++j);

            for ( ; i < j; ++i) {
                ans[ord[i]] = j - 1;
            }
        }
        return ans;
    }
};

签名讲解

代码有点长,我也写了很久,优化了更久!

我先来讲解一下 kDPO 以及它的 operator () 的签名。

kDPO 的构造函数只接受一个 const vector<array<Tp, k>> &arr,设 narr.size()arr 描述了 n 个有 k 维属性的元素,属性的类型为 Tparr[i][0], arr[i][1], \dots, arr[i][k - 1], 依次描述了 k 维属性的值。

调用 kDPOoperator () 即开始求解 k 维偏序。t它的时间复杂度为 O(nk \log n + n \log^{k-1} n),空间复杂度为 \Theta(nk)。它返回一个长度为 nvector<int>,设其名为 ans,则 ans[i]arr[i] 偏序不同于它的元素的个数。形式化的:

ans[i] = \left(\sum_{j=0}^{n-1} \left[\bigwedge_{l=0}^{k-1} arr[j][l] \leq arr[i][l]\right]\right) - 1

示例

对于经典题目 P3810 【模板】三维偏序(陌上花开):

#include <bits/stdc++.h>
using namespace std;

template <typename Tp, int K>
struct kDPO {

};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k;
    cin >> n >> k;
    vector<array<int, 3>> arr(n);
    for (array<int, 3> &i : arr) {
        for (int &j : i) {
            cin >> j;
        }
    }
    vector<int> res = kDPO<int, 3>(arr)(), ans(n);
    for (int i : res) {
        ++ans[i];
    }
    for (int i : ans) {
        cout << i << '\n';
    }
    return 0;
}

表现:提交记录

k 为变量

template <typename Tp>
struct kDPO<Tp, 0> {
    const vector<vector<Tp>> &arr;
    int k;

    kDPO(const vector<vector<Tp>> &arr_) : arr(arr_) {
        k = arr[0].size();
    }

    vector<int> a, cnt, ans;
    vector<vector<int>> b;
    vector<int> hero;
    void cdq(int l, int r) {
        if (l == r) {
            return ;
        }

        int mid = (l + r) >> 1;
        cdq(l, mid);
        cdq(mid + 1, r);

        int i = l, j = mid + 1;
        b[0].clear();
        while (i <= mid || j <= r) {
            if (j > r || (i <= mid && arr[a[i]][1] <= arr[a[j]][1])) {
                b[0].push_back(a[i]);
                ++i;
            } else {
                b[0].push_back(a[j] | 1 << 31);
                ++j;
            }
        }
        for (int i = 0; i < r - l + 1; ++i) {
            a[i + l] = b[0][i] & 2147483647;
        }

        cdq(0, r - l, 0);
    }
    void cdq(int l, int r, int d) {
        if (l == r) {
            return ;
        }

        int mid = (l + r) >> 1;
        cdq(l, mid, d);
        cdq(mid + 1, r, d);

        int i = l, j = mid + 1;
        hero.clear();
        if (d == k - 3) {
            int sum = 0;
            while (i <= mid || j <= r) {
                if (j > r || (i <= mid && arr[b[d][i] & 2147483647][k - 1] <= arr[b[d][j] & 2147483647][k - 1])) {
                    sum += b[d][i] >> 31 ? 0 : cnt[b[d][i] & 2147483647];
                    hero.emplace_back(b[d][i]);
                    ++i;
                } else {
                    ans[b[d][j] & 2147483647] += b[d][j] >> 31 ? sum : 0;
                    hero.emplace_back(b[d][j]);
                    ++j;
                }
            }
            memcpy(b[d].data() + l, hero.data(), sizeof(int) * (r - l + 1));
            return ;
        }
        b[d + 1].clear();
        while (i <= mid || j <= r) {
            if (j > r || (i <= mid && arr[b[d][i] & 2147483647][d + 2] <= arr[b[d][j] & 2147483647][d + 2])) {
                if (!(b[d][i] >> 31)) {
                    b[d + 1].push_back(b[d][i]);
                }
                hero.emplace_back(b[d][i]);
                ++i;
            } else {
                if (b[d][j] >> 31) {
                    b[d + 1].push_back(b[d][j]);
                }
                hero.emplace_back(b[d][j]);
                ++j;
            }
        }
        memcpy(b[d].data() + l, hero.data(), sizeof(int) * (r - l + 1));

        if (!b[d + 1].empty()) {
            cdq(0, (int)b[d + 1].size() - 1, d + 1);
        }
    }

    vector<int> operator () () {
        int n = (int)arr.size();
        if (k == 1) {
            vector<array<Tp, 1>> a(n);
            for (int i = 0; i < n; ++i) {
                a[i][0] = arr[i][0];
            }
            return kDPO<Tp, 1>(a)();
        }
        if (k == 2) {
            vector<array<Tp, 2>> a(n);
            for (int i = 0; i < n; ++i) {
                a[i][0] = arr[i][0];
                a[i][1] = arr[i][1];
            }
            return kDPO<Tp, 2>(a)();
        }       

        vector<int> ord(n);
        iota(ord.begin(), ord.end(), 0);
        sort(ord.begin(), ord.end(), [&](int x, int y) {
            return arr[x] < arr[y];
        });

        vector<pair<int, int>> equa;
        equa.reserve(n);
        cnt.resize(n);
        for (int i = 0; i < n; ) {
            int j = i + 1;
            for ( ; j < n && arr[ord[j]] == arr[ord[i]]; ++j) {
                equa.emplace_back(ord[j], ord[i]);
            }

            a.emplace_back(ord[i]);
            cnt[ord[i]] = j - i;

            i = j;
        }
        ans.resize(n);
        hero.reserve(n);
        b.resize(k - 2);
        for (vector<int> &i : b) {
            i.reserve(n);
        }

        cdq(0, (int)a.size() - 1);

        for (int i = 0; i < n; ++i) {
            ans[i] += cnt[i] - 1;
        }
        for (pair<int, int> i : equa) {
            ans[i.first] = ans[i.second];
        }
        return ans;
    }
};