【模板】带权并查集

· · 个人记录

对于每一个集合维护点与根的权值。

如果对权值取模,可以得到种类并查集,相当于同时维护数与数之间的多种关系,如包含与不包含。

典型例题:P2024 [NOI2001] 食物链

两种实现方式:(以 P2024食物链 为例)

1.维护点到根的关系值,通过权值判断两点是否在同一集合。

实现如下:

#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 50010
// 0同类 1吃 2被吃 
inline int read()
{
    int s=0,f=1;
    char ch=getchar();
    while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
    while('0'<=ch&&ch<='9') {s=s*10+ch-'0';ch=getchar();}
    return s*f;
}
int fa[N],re[N];
int find(int x)
{
    int f=fa[x];
    if(x!=fa[x])
    {
        fa[x]=find(fa[x]);
        re[x]=(re[x]+re[f])%3;
        return fa[x];
    }
    else return x;
}
void join(int x,int y,int reSon)
{
    int xx=find(x),yy=find(y);
    //re[x]+reFa==re[y]+reSon (mod 3); 
    if(xx!=yy)
    {
        fa[xx]=yy;
        re[xx]=(re[y]-re[x]+reSon+3)%3;
    }
}
int n,m; 
int ans;
int main()
{
    n=read();m=read();
    for(int i=1;i<=n;++i)
    {
        re[i]=0;fa[i]=i;
    }
    while(m--)
    {
        int a,b,c;
        c=read()-1;a=read();b=read();
        if((c&&a==b)||(a>n||b>n))
        {
            ++ans;continue;
        }
        if(!c)
        {
            int aa=find(a),bb=find(b);
            if(aa==bb)
            {
                if(re[a]!=re[b])
                {
                    ++ans;continue;
                }
            }
            else
            {
                join(a,b,c);
            }
        }
        if(c) 
        {
            int aa=find(a),bb=find(b);
            if(aa==bb)
            {
                //reAb+re[b]==re[a] (mod 3)
                int reAb=(re[a]-re[b]+3)%3;
                if(reAb!=1)
                {
                    ++ans;continue;
                }
            }
            else 
            {
                join(a,b,c);
            }
        }
    } 
    printf("%d",ans);
    return 0;
}

2.维护每个点维护 逆向点(具体实现为其倍数) 通过与一个点的逆向点连边维护两点关系。

实现如下:

#include<bits/stdc++.h>
using namespace std;
int fa[300030],n,k,cnt;
//n^0 同类 n^1吃 n^2被吃
inline int find(int x)
{
    return x==fa[x]?x:fa[x]=find(fa[x]); 
}
inline void join(int x,int y)
{
    if(find(x)!=find(y)) fa[find(y)]=find(x);
}
int main()
{
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n*3;++i) fa[i]=i;
    for(int i=1;i<=k;++i)
    {
        int c,u,v;
        scanf("%d%d%d",&c,&u,&v);
        if(u>n||v>n) 
        {
            cnt++;continue;
        }
        if(c==2&&u==v) 
        {
            cnt++;continue;
        }
        if(c==1)
        {
            if(find(u+n)==find(v)||find(u+n*2)==find(v)) 
            {
                cnt++;
                continue;
            }
            join(u,v);join(u+n,v+n);join(u+2*n,v+2*n);
        }
        else if(c==2)
        {

            if(find(u)==find(v)||find(u+2*n)==find(v)) 
            {
                cnt++;
                continue;
            }
            join(u,v+2*n);join(u+n,v);join(u+2*n,v+n);
        }
    }
    printf("%d",cnt);
    return 0;
}

还有一种用法是维护点与点之间的权值关系,如维护距离。

典型例题P1196 [NOI2002] 银河英雄传说

代码参考:

#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 30010 
int fa[N],re[N],size[N];
int T;
int find(int x)
{
    int f=fa[x];
    if(x!=f)
    {
        fa[x]=find(fa[x]);
        re[x]+=re[f];
        return fa[x];
    } 
    return x;
}
void join(int x,int y)
{
    int xx=find(x),yy=find(y);
    if(xx==yy) return;
    fa[xx]=yy;
    //re[x]+reFa==reSon+re[y]
    re[xx]=size[yy];
    size[yy]+=size[xx];
}
int main()
{
    scanf("%d",&T);
    for(int i=1;i<=30000;++i) 
    {
        fa[i]=i;re[i]=0;size[i]=1;
    }
    while(T--)
    {
        char ch;
        do{ch=getchar();}while(ch!='M'&&ch!='C');
        int a,b;
        scanf("%d%d",&a,&b);
        if(ch=='M')
        {
            //a->b
            join(a,b);
        }
        if(ch=='C')
        {
            int aa=find(a),bb=find(b);
            if(aa!=bb)
            {
                printf("-1\n");
                continue;
            }
            else 
            {
                printf("%d\n",max(abs(re[a]-re[b])-1,0));
                continue;
            }
        } 
    }
    return 0;
}