并查集

· · 算法·理论


所以就有了这篇文章。

模板题 和 双倍经验 P1551 亲戚。

并查集,顾名思义,就是支持合并和查询的集合。用于管理元素所属集合的数据结构,实现为一个森林,其中每棵树表示一个集合,树中的节点表示对应集合中的元素。

集合怎么可以合并呢?由于并查集中每个集合对应森林中的一棵树,所以你就可以顺利的理解这幅图:

那么它所对应的集合就应分别是 \{1,2,3,4,5\}\{6,7,8\}

合并的实现是使得两个集合拥有同一个代表元素,一种简单的实现方法是将一个集合的根节点连接到另一个集合的根节点,例如:

那么如何查询呢?根据我们在前文对该集合的定义可以得知,该集合可以抽象成一棵树。树有什么东西是唯一的呢?——显然,根节点是唯一的。如果两个元素在同一个集合中,那么他们的根节点就一定相同;反之,不一定相同(例如两个集合\{1,2,3,4,5\}\{1,7,8\})。但是很幸运,本题 集合中的元素为序号(1\sim N),也就代表着每个元素都不相同。

所以查询的实现方法就得出了,即是找自身所在树的根节点。显然,根节点相同的两个元素处在同一个集合中。

最朴素的查询和合并的时间复杂度都是 O(N)(当该集合退化到一条链时)。

这样有些慢,我们可以使用路径压缩来优化它。
如果当前节点不是根节点,则继续查找当前节点的父节点的根节点。同时,将当前节点的父节点直接设为最终找到的根节点,这样下次从从同一个节点找其根节点时所经过的节点数会大大减少。时间复杂度优化至 O(\log N)
由于合并只需要查找当前节点的根节点并将根节点的父节点设置为另一个点的根节点,所以时间复杂度也是 O(\log N)
具体代码如下:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e5+10;
int f[MAXN];//f数组表示每个元素的父亲
int find(int x)
{
    if(x==f[x]) return x;
    return f[x]=find(f[x]);
}
int main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);cout.tie(nullptr);
    int n,m;
    int z,x,y;
    cin>>n>>m;
    //每个元素初始在独立的一个集合里,所以的初始父亲就是自己
    for(int i=0;i<n;i++) f[i]=i;
    for(int i=0;i<m;i++)
    {
        cin>>z>>x>>y;
        if(z==1)
        {
            if(find(x)!=find(y)) f[find(x)]=find(y);
        }
        if(z==2)
        {
            if(find(x)==find(y)) cout<<"Y"<<endl;
            else cout<<"N"<<endl;
        }
    }
    return 0;
}

上文给出的代码已经可以通过模板题,但是我们还有一种更快的做法:按秩合并(也是启发式合并的一种)。

对于一个集合,秩指的就是这个集合中元素的个数。在并查集的合并操作中,如果我们将秩较小的集合合并到秩更大的集合上,那么比起将秩较大的集合合并到秩更小的集合上,在每一次查询操作中,前者总能花费更少的时间。

在并查集中,我们把所有集合合并起来后,增加的总代价也不会超过 N\log N。也就是说,单次查询的平均时间复杂度为 O(\log N)

那么同时使用路径压缩和按秩合并的优化呢?如果我们这么做的话,单次查询操作的时间复杂度会变成一个奇怪的东西:O(α(N))。其中 α(N) 为反阿克曼函数,它是一个比 \log N 增长得还要慢很多的函数。
由于作者太弱了,复杂度证明看这里:https://oi-wiki.org/ds/dsu-complexity/

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e5+10;
int f[MAXN],size[MAXN];//size数组记录各集合的秩
int find(int x)
{
    if(x==f[x]) return x;
    return f[x]=find(f[x]);
}
void merge(int x,int y)
{
    if(size[find(x)]>size[find(y)]) 
    {
        f[find(y)]=find(x);
        size[find(x)]+=size[find(y)];
    } 
    else 
    {
        f[find(x)]=find(y);
        size[find(y)]+=size[find(x)];
    }
}
int main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);cout.tie(nullptr);
    int n,m;
    int z,x,y;
    cin>>n>>m;
    for(int i=0;i<n;i++)
    { 
        f[i]=i;
        size[i]=1;//最开始每个集合的秩都是1
    }
    for(int i=0;i<m;i++)
    {
        cin>>z>>x>>y;
        if(z==1)
        {
            if(find(x)!=find(y)) merge(x,y);
        }
        if(z==2)
        {
            if(find(x)==find(y)) cout<<"Y"<<endl;
            else cout<<"N"<<endl;
        }
    }
    return 0;
}

但这样还是有些问题,注意到我们多次调用了 find() 函数,并且 size 在 C++17时是关键词,我们或许可以优化一下:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e5+10;
int f[MAXN],siz[MAXN];
int find(int x)
{
    if(x==f[x]) return x;
    return f[x]=find(f[x]);
}
void merge(int x,int y)
{
    int fx=find(x),fy=find(y);
    if(siz[fx]>siz[fy]) f[fy]=fx,siz[fx]+=siz[fy];
    else f[fx]=fy,siz[fy]+=siz[fx];
}
int main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);cout.tie(nullptr);
    int n,m;
    int z,x,y;
    cin>>n>>m;
    for(int i=0;i<n;i++)
    { 
        f[i]=i;
        siz[i]=1;
    }
    for(int i=0;i<m;i++)
    {
        cin>>z>>x>>y;
        if(z==1)
        {
            if(find(x)!=find(y)) merge(x,y);
        }
        if(z==2)
        {
            if(find(x)==find(y)) cout<<"Y"<<endl;
            else cout<<"N"<<endl;
        }
    }
    return 0;
}