你所不了解的算法——BFPRT 算法

· · 个人记录

给定一个长 n 的无序数组和一个 1\sim n 之间的整数 k,求出该无序数组中第 k 大的元素。

—— Top k 问题。

算法一

我会排序!

## 算法二 我会排序! 进行 $k$ 轮冒泡排序,这样就得到了前 $k$ 小。 时间复杂度 $O(nk)$,在 $k$ 小的离谱的时候有用。 ## 算法三 我会使用堆! 维护一个大小为 $k$ 的堆,先把前 $k$ 个元素入堆,然后遍历后面的元素,如果这个元素比堆顶小,弹掉堆顶并插入这个元素,遍历完成后依次弹掉堆顶,最后一个弹掉的元素就是第 $k$ 小。 时间复杂度 $O(n\log k)$。 ## 算法四 我知道快速排序的原理! 假设我们要求一个数组 $a$ 的第 $k$ 大,我们从 $a$ 中随机选择一个分界点 $x$,将小于 $x$ 的元素放在左边,大于 $x$ 的元素放在右边,$x$ 放在中间。 如果左边的元素个数大于 $k$,递归左边的元素,且 $k$ 不变。 如果左边的元素个数等于 $k$,返回 $x$。 如果左边的元素个数大于 $k$,递归右边的元素,这时 $k$ 要减去左边的元素再减 $1$(要排除 $x$)。 期望时间复杂度 $O(n)$。 ## 算法五 使用 BFPRT 算法(也叫 Median of Medians 算法)。 该算法由 Blum、Floyd、Pratt、Rivest、Tarjan 发明,再此向 Tarjan 爷爷表示膜拜。 先把 $n$ 个数 $5$ 个一组分成 $\left\lceil\dfrac n5\right\rceil$ 组,其中只有最后一组有 $n\bmod5$ 个数。 然后对分成 $\left\lceil\dfrac n5\right\rceil$ 组数,暴力计算出这些数组的中位数,得到了一个长 $\left\lceil\dfrac n5\right\rceil$ 的数组,然后**调用 BFPRT 算法计算出中位数**(注意这里必须调用 BFPRT,不然复杂度分析会失效)。 得到这 $\left\lceil\dfrac n5\right\rceil$ 个数的中位数后,将这个数作为 $x$ 进行划分,然后根据原来的算法递归调用。 我们发现,这个算法的奇怪之处就是 $x$ 的选取,显然,选择的 $x$ 越接近中间,那么算法的效率就越高。 我们将 $a$ 的元素五个一列画出来,然后用箭头表示大小,小的指向大的,就可以得到下面的图: ![](https://cdn.luogu.com.cn/upload/image_hosting/23kppnjm.png) 显然,所有灰色区域内的元素都能确定大于 $x$,而灰色区域内包含大约 $3\times\dfrac n{10}$ 个元素,所以 $x$ 就在数组中 $30\%$ 到 $70\%$ 的地方。 那么我们就可以分析复杂度了,设 $T(n)$ 为 BFPRT 算法查询 $n$ 个元素的第 $k$ 小时的操作次数,考虑最差情况,即每次递归都会选取 $\dfrac{7n}{10}$ 个元素递归下去,则有递推式: $$T(1)=O(1),T(n)=T\left(\frac n5\right)+T\left(\frac{7n}{10}\right)+O(n)$$ 我们设存在一个常数 $c$ 使得 $T(n)\le cn$,并设后面的 $O(n)\le an$,则有: $$\begin{aligned}T(n)&\le\frac{cn}5+\frac{7cn}{10}+an\\&=\frac{9cn}{10}+an\\&=cn+\left(an-\frac{cn}{10}\right)\end{aligned}$$ 我们令 $c\ge10a$,就可以得到 $T(n)\le cn$ 的结论,因此,$T(n)=O(n)$。 参考代码: ```cpp #include <cstdio> using namespace std; int BFPRT(int * a, int n, int k); int median_of_median(int * a, int n); /* 对 [a, a + n) 中的元素重新排列, 其中小于 x 的都在左边, 大于等于 x 的都在右边. 需要保证 x 为 [a, a + n) 的元素. */ int partition(int * a, int n, int x) { int * b = new int [n]; int l = 0, r = n - 1; for (int i = 0; i < n; i++) { if (a[i] < x) { b[l] = a[i]; l++; } else { b[r] = a[i]; r--; } } for (int i = 0; i < n; i++) { a[i] = b[i]; } return l; } /* 找到 [a, a + n) 中的第 k 小. */ int BFPRT(int * a, int n, int k) { if (n == 1) { return a[0]; } else { int x = median_of_median(a, n); int p = partition(a, n, x); if (k == p) { return a[x]; } else if (k < p) { return BFPRT(a, p - 1, k); } else { return BFPRT(a + p + 1, n - p - 1, k - p - 1); } } } void swap(int & a, int & b) { a ^= b ^= a ^= b; } /* 暴力计算 [a, a + n) 的中位数. */ int compute_median(int * a, int n) { for (int i = 0; i < n; i++) { for (int j = i + 1; j < n; j++) { if (a[i] > a[j]) { swap(a[i], a[j]); } } } return a[n / 2]; } /* 计算 BFPRT 中作为分界点的 x. */ int median_of_median(int * a, int n) { int * median = new int [(n + 4) / 5]; for (int i = 0; i < n / 5; i++) { median[i] = compute_median(a + i * 5, 5); } if (n % 5 != 0) { median[n / 5] = compute_median(a + n / 5 * 5, n % 5); } return BFPRT(a, (n + 4) / 5, (n + 4) / 10); } int main() { int n, k; scanf("%d %d", &n, &k); int * a = new int [n]; for (int i = 0; i < n; i++) { scanf("%d", &a[i]); } printf("%d", BFPRT(a, n, k)); return 0; } ```