并查集

· · 算法·理论

并查集 Union-Find Set

高效处理集合的合并与查询。

并查集:

  1. 一开始,所有点自己是一个集合。
  2. 连线关系,合并集合。
  3. 查询两点是否在同一个集合

并查集三件套

  1. 查询:\text{find()}
  2. 合并:\text{merge()}
  3. 初始化:\text{init()}

\text{find()}

查询 x 节点的根节点编号

朴素算法:

if(fa[x] == x) return x;
return find(fa[x]);

路径压缩:

int find(int x)
{
    if(fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

\text{merge()}

合并a,\ b所在集合

void merge(int a, int b) {fa[find(a)] = fa[find(b)];}

\text{init()}

初始化,每个节点就是自己的父亲

void init(int n) {for(int i=0;i<=n;i++) fa[i] = i;}

优化

按秩合并

按高度合并

这里的 p 数组是存储父节点的,h 数组是这个节点到根节点估计的高度。

可以证明复杂度优化为 O(\log n)

int p[MAXN];
int h[MAXN];

void Initial(int n){
    for(int i = 0; i < n; i++) {
        p[i] = i;
        h[i] = 1;
    }
}

void Union(int x, int y) {
    int rootx = Find(x);
    int rooty = Find(y);
    if(rootx != rooty){
        if(h[rootx] < h[rooty]){
            p[rootx] = rooty;
        } else if(h[rootx] > h[rooty]){
            p[rooty] = rootx;
        } else {
            p[rootx] = rooty;
            h[rooty]++;
        }
    }
}

按大小合并

这里的 p 数组是存储父节点的,s 数组是这个节点所在的这棵树的所有结点总和。

可以证明复杂度优化为 O(\log n)

int p[MAXN];
int s[MAXN];

void Initial(int n){
    for(int i = 0; i < n; i++) {
        p[i] = i;
        s[i] = 1;
    }
}

void Union(int x, int y){
    int rootx = Find(x);
    int rooty = Find(y);
    if(rootx != rooty){
        if(s[rootx] <= s[rooty]){
            p[rootx] = rooty;
            s[rooty] += s[rootx];
        } else {
            p[rooty] = rootx;
            s[rootx] += s[rooty];
        }
    }
}

并查集扩展

扩展域并查集

P2024 [NOI2001] 食物链


#include<bits/stdc++.h>
using namespace std;

const int MAX = 5e4+5;
int fa[3*MAX], n, m;

int find(int x) {return fa[x] == x ? x : fa[x] = find(fa[x]);}
#define same(x, y) find(x) == find(y)
void merge(int x, int y) {fa[find(x)] = find(y);}
void init() {for(int i=1; i<=3*n; i++) fa[i] = i;}

// 扩展域并查集的核心: 
// 对于相同并查集内的命题们只要这些命题中有一个为真,那么必定全为真
int main()
{
    cin >> n >> m;
    init();

    int ans = 0;
    while(m--)
    {
        int op, x, y; cin >> op >> x >> y;
        if(x > n || y > n) {ans++; continue;}
        if(op == 1)
        {
            // x 与吃 y 的同类或者 x 与被 y 的吃的同类,一定是假话 
            if(same(x, y+n) || same(x, y+2*n)) ans ++;
            else
            {
                merge(x, y); // x 与 y 同类
                merge(x+n, y+n); // 吃 x 的和吃 y 的同类 
                merge(x+2*n, y+2*n); // 被 x 吃的和被 y 吃的同类 
            }
        }
        else
        {
            // x 与 y 同类或者 x 与被 y 吃的同类,一定是假话
            if(same(x, y) || same(x, y+2*n)) ans ++;
            else
            {
                merge(x, y+n); // x 与吃 y 的同类
                merge(x+n, y+2*n); // 吃 x 的与被 y 吃的同类
                merge(x+2*n, y); // 被 x 吃的与 y 同类 
            }
        }
    }
    cout << ans;

    return 0;
}

带权并查集

P1196 [NOI2002] 银河英雄传说

#include <bits/stdc++.h>
using namespace std;

const int MAX = 3e5+5;

// 并查集数组
int parent[MAX];    // 父节点数组
int dist[MAX];      // 到根节点的距离
int size[MAX];      // 集合的大小(飞船数量)

// 初始化并查集
void init() 
{
    for (int i = 1; i <= MAX-5; i++) 
    {
        parent[i] = i;  // 初始时每个节点都是自己的父节点
        dist[i] = 0;    // 到根节点的距离为0
        size[i] = 1;    // 每个集合初始大小为1
    }
}

// 查找操作,带路径压缩
int find(int x) 
{
    if (parent[x] == x) {return x;}

    int root = find(parent[x]);  // 递归找到根节点

    // 路径压缩:在回溯时更新距离
    dist[x] += dist[parent[x]];  // 当前节点距离 = 到父节点的距离 + 父节点到根节点的距离
    parent[x] = root;            // 直接连接到根节点

    return root;
}

// 合并操作:将飞船i所在的列接到飞船j所在的列后面
void merge(int i, int j) {
    int root_i = find(i);  // i的根节点
    int root_j = find(j);  // j的根节点

    if (root_i == root_j) {return;} // 已经在同一列,不需要合并

    // 将i所在的列接到j所在的列后面
    parent[root_i] = root_j;          // 将root_i的父节点设为root_j
    dist[root_i] = size[root_j];      // root_i到新根节点的距离为原来j列的长度
    size[root_j] += size[root_i];     // 更新新列的长度
}

// 处理查询
void solve() 
{
    char command;
    int i, j;
    cin >> command >> i >> j;

    if (command == 'M') 
    {
        // M命令:合并飞船i和j所在的列
        merge(i, j);
    } 
    else 
    {
        // C命令:查询飞船i和j之间的飞船数量
        if (find(i) == find(j)) 
        {
            // 在同一列中,计算它们之间的飞船数量
            // |dist[i] - dist[j]| - 1 表示i和j之间的飞船数量
            cout << abs(dist[i] - dist[j]) - 1 << endl;
        } 
        else 
        {
            // 不在同一列
            cout << -1 << endl;
        }
    }
}

int main() 
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T;
    cin >> T;
    init();  // 初始化并查集
    while (T--) solve();

    return 0;
}