P3224 [HNOI2012] 永无乡 题解

· · 题解

感谢 MyShiroko 提供思路!

大家好,我喜欢暴力数据结构,所以我用暴力通过了此题。

发现 vector 启发式合并(把小 vector 暴力复制在大 vector 的后面)的复杂度总共是 O(n \log n) 的,所以我们考虑使用它。

题目还要求我们求第 k 小,朴素的想法是:在查询时,如果这个块没有排好序(执行过合并操作),那么先 sort 一遍。然后直接输出第 k 小即可。

交上去之后发现我们只获得了 60 pts 的好成绩? 数据还是过于淼了。

发现瓶颈在于求第 k 小,sort 当块内元素数量(设为 cnt)过多时,复杂度极劣,再随便卡卡,复杂度直接爆炸。

考虑优化。

为防止上述情况发生,我们可以设立一个阈值 B

cnt \le B 时,直接 sort。

cnt > B 时,暴力把这个块中的元素进行值域分块,那么一次查询复杂度降为 O(\sqrt{n}),可以通过本题。

显然对于每个 cnt > B 的块,我们只需要执行上述操作一次,之后在合并时把小块中元素放入大的块的值域分块数组,这样就保证复杂度正确。

实际情况中取 B \approx 1000 时最优。

喜提最优解。

Code:

#include <vector>
#include <algorithm>
#include <stdio.h>
#include <cmath>
#include <bitset>
// #define int long long

// using namespace std;
using std::bitset;
using std::sort;
using std::swap;
using std::vector;

const int Size = (1 << 20) + 1;
char buf[Size], *p1 = buf, *p2 = buf;
char buffer[Size];
int op1 = -1;
const int op2 = Size - 1;
#define getchar()                                                          \
    (tt == ss && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) \
         ? EOF                                                             \
         : *ss++)
char In[1 << 20], *ss = In, *tt = In;
inline int read()
{
    int x = 0, c = getchar(), f = 0;
    for (; c > '9' || c < '0'; f = c == '-', c = getchar());
    for (; c >= '0' && c <= '9'; c = getchar())
        x = (x << 1) + (x << 3) + (c ^ 48);
    return f ? -x : x;
}
inline void write(int x)
{
    if (x < 0) x = -x, putchar_unlocked('-');
    if (x > 9) write(x / 10);
    putchar_unlocked(x % 10 + '0');
}

vector<int> v[100001];
int n, m;

int fa[100001];
int siz[100001];

int find(int x)
{
    if (x == fa[x]) return x;
    return fa[x] = find(fa[x]);
}

int tot;
const int B = 1024;
unsigned short cnt1[100001 / B + 1][100001], cnt2[100001 / B + 1][1005];
int id[100001];
// bool vis[100001];
bitset<100001> vis;
int to[100001];
int len = 300;

#define add2(id, x)              \
    {                            \
        cnt1[id][x]++;           \
        cnt2[id][x / len + 1]++; \
    }

void add()
{
    int x = read(), y = read();
    x = find(x);
    y = find(y);
    if (x == y) return;
    if (siz[y] > siz[x]) swap(x, y);

    fa[y] = x;
    siz[x] += siz[y];
    vis[x] = 1; // vis 记录这个块是否需要排序

    v[x].reserve(v[y].size() + 1);

    if (id[x]) // 如果大的块已经被值域分块过
    {
        for (int i = 0; i < v[y].size(); i++)
        {
            add2(id[x], v[y][i]);
            v[x].push_back(v[y][i]); // 复制的同时 小块中元素放入大的块的值域分块数组
        }
        return;
    }
    v[x].insert(v[x].end(), v[y].begin(), v[y].end());
}

void query()
{
    int x = read(), k = read();
    x = find(x);

    if (siz[x] > B && !id[x]) // 达到阈值,进行值域分块
    {
        id[x] = ++tot;
        for (int i = 0; i < v[x].size(); i++)
            add2(id[x], v[x][i]);
    }
    if (id[x]) // 值域分块求第 k 小
    {
        for (int i = 1; i <= n / len + 2; i++)
        {
            k -= cnt2[id[x]][i];
            if (k <= 0)
            {
                k += cnt2[id[x]][i];
                int l = (i - 1) * len;
                while (k > 0)
                {
                    k -= cnt1[id[x]][l];
                    if (k <= 0)
                    {
                        write(to[l]);
                        putchar_unlocked('\n');
                        return;
                    }
                    l++;
                }
            }
        }
        puts("-1");
        return;
    }
    else // sort 求第 k 小
    {
        if (vis[x])
        {
            sort(v[x].begin(), v[x].end());
            vis[x] = 0;
        }
        if (v[x].size() < k)
        {
            puts("-1");
            return;
        }
        else write(to[v[x][k - 1]]), putchar_unlocked('\n');
    }
}

int main()
{
    n = read();
    m = read();
    len = sqrt(n);
    register int x, i;
    for (i = 1; i <= n; i++)
    {
        x = read();
        fa[i] = i;
        siz[i] = 1;
        v[i].push_back(x);
        to[x] = i;
    }
    for (i = 1; i <= m; i++) add();
    register int q = read(), op;
    for (i = 1; i <= q; i++)
    {
        op = 0;
        for (; op < 'A' || op > 'Z'; op = getchar());
        if (op == 'B') add();
        else query();
    }
    return 0;
}