你所不了解的算法——BFPRT 算法
cancan123456
·
·
个人记录
给定一个长 n 的无序数组和一个 1\sim n 之间的整数 k,求出该无序数组中第 k 大的元素。
—— Top k 问题。
算法一
我会排序!
## 算法二
我会排序!
进行 $k$ 轮冒泡排序,这样就得到了前 $k$ 小。
时间复杂度 $O(nk)$,在 $k$ 小的离谱的时候有用。
## 算法三
我会使用堆!
维护一个大小为 $k$ 的堆,先把前 $k$ 个元素入堆,然后遍历后面的元素,如果这个元素比堆顶小,弹掉堆顶并插入这个元素,遍历完成后依次弹掉堆顶,最后一个弹掉的元素就是第 $k$ 小。
时间复杂度 $O(n\log k)$。
## 算法四
我知道快速排序的原理!
假设我们要求一个数组 $a$ 的第 $k$ 大,我们从 $a$ 中随机选择一个分界点 $x$,将小于 $x$ 的元素放在左边,大于 $x$ 的元素放在右边,$x$ 放在中间。
如果左边的元素个数大于 $k$,递归左边的元素,且 $k$ 不变。
如果左边的元素个数等于 $k$,返回 $x$。
如果左边的元素个数大于 $k$,递归右边的元素,这时 $k$ 要减去左边的元素再减 $1$(要排除 $x$)。
期望时间复杂度 $O(n)$。
## 算法五
使用 BFPRT 算法(也叫 Median of Medians 算法)。
该算法由 Blum、Floyd、Pratt、Rivest、Tarjan 发明,再此向 Tarjan 爷爷表示膜拜。
先把 $n$ 个数 $5$ 个一组分成 $\left\lceil\dfrac n5\right\rceil$ 组,其中只有最后一组有 $n\bmod5$ 个数。
然后对分成 $\left\lceil\dfrac n5\right\rceil$ 组数,暴力计算出这些数组的中位数,得到了一个长 $\left\lceil\dfrac n5\right\rceil$ 的数组,然后**调用 BFPRT 算法计算出中位数**(注意这里必须调用 BFPRT,不然复杂度分析会失效)。
得到这 $\left\lceil\dfrac n5\right\rceil$ 个数的中位数后,将这个数作为 $x$ 进行划分,然后根据原来的算法递归调用。
我们发现,这个算法的奇怪之处就是 $x$ 的选取,显然,选择的 $x$ 越接近中间,那么算法的效率就越高。
我们将 $a$ 的元素五个一列画出来,然后用箭头表示大小,小的指向大的,就可以得到下面的图:

显然,所有灰色区域内的元素都能确定大于 $x$,而灰色区域内包含大约 $3\times\dfrac n{10}$ 个元素,所以 $x$ 就在数组中 $30\%$ 到 $70\%$ 的地方。
那么我们就可以分析复杂度了,设 $T(n)$ 为 BFPRT 算法查询 $n$ 个元素的第 $k$ 小时的操作次数,考虑最差情况,即每次递归都会选取 $\dfrac{7n}{10}$ 个元素递归下去,则有递推式:
$$T(1)=O(1),T(n)=T\left(\frac n5\right)+T\left(\frac{7n}{10}\right)+O(n)$$
我们设存在一个常数 $c$ 使得 $T(n)\le cn$,并设后面的 $O(n)\le an$,则有:
$$\begin{aligned}T(n)&\le\frac{cn}5+\frac{7cn}{10}+an\\&=\frac{9cn}{10}+an\\&=cn+\left(an-\frac{cn}{10}\right)\end{aligned}$$
我们令 $c\ge10a$,就可以得到 $T(n)\le cn$ 的结论,因此,$T(n)=O(n)$。
参考代码:
```cpp
#include <cstdio>
using namespace std;
int BFPRT(int * a, int n, int k);
int median_of_median(int * a, int n);
/*
对 [a, a + n) 中的元素重新排列, 其中小于 x 的都在左边, 大于等于 x 的都在右边.
需要保证 x 为 [a, a + n) 的元素.
*/
int partition(int * a, int n, int x) {
int * b = new int [n];
int l = 0, r = n - 1;
for (int i = 0; i < n; i++) {
if (a[i] < x) {
b[l] = a[i];
l++;
} else {
b[r] = a[i];
r--;
}
}
for (int i = 0; i < n; i++) {
a[i] = b[i];
}
return l;
}
/*
找到 [a, a + n) 中的第 k 小.
*/
int BFPRT(int * a, int n, int k) {
if (n == 1) {
return a[0];
} else {
int x = median_of_median(a, n);
int p = partition(a, n, x);
if (k == p) {
return a[x];
} else if (k < p) {
return BFPRT(a, p - 1, k);
} else {
return BFPRT(a + p + 1, n - p - 1, k - p - 1);
}
}
}
void swap(int & a, int & b) {
a ^= b ^= a ^= b;
}
/*
暴力计算 [a, a + n) 的中位数.
*/
int compute_median(int * a, int n) {
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
if (a[i] > a[j]) {
swap(a[i], a[j]);
}
}
}
return a[n / 2];
}
/*
计算 BFPRT 中作为分界点的 x.
*/
int median_of_median(int * a, int n) {
int * median = new int [(n + 4) / 5];
for (int i = 0; i < n / 5; i++) {
median[i] = compute_median(a + i * 5, 5);
}
if (n % 5 != 0) {
median[n / 5] = compute_median(a + n / 5 * 5, n % 5);
}
return BFPRT(a, (n + 4) / 5, (n + 4) / 10);
}
int main() {
int n, k;
scanf("%d %d", &n, &k);
int * a = new int [n];
for (int i = 0; i < n; i++) {
scanf("%d", &a[i]);
}
printf("%d", BFPRT(a, n, k));
return 0;
}
```