并查集进阶

个人记录

银河英雄传说

边带权并查集模板。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
// Debug helper that prints a separator line; not used in this solution.
#define debug printf("\n-------------\n")
using namespace std;
typedef long long ll;
// General-purpose constants; none of them is referenced in this solution.
const int INF1=(~0u>>2);
const ll INF2=(~0ull>>2),P1=10007,P2=998244353,P3=1000000007;
// T:      number of commands read by main().
// fa[i]:  union-find parent of element i.
// a[i]:   offset of i relative to fa[i]; find() accumulates it along the path,
//         main() adds cnt[fy] to a[fx] when root fx is attached under root fy.
// cnt[r]: element counter of the set rooted at r (starts at 1, summed on merge,
//         reset to 0 for the root that was absorbed).
int T,fa[30010],a[30010],cnt[30010];
// Absolute value of x.
int ab(int x)
{
    if(x<0)
        return -x;
    return x;
}
// Returns the root of u's set and compresses the path to it.
// While the recursion unwinds, a[u] absorbs the already accumulated offset of
// its former parent, so afterwards a[u] is u's offset relative to the root.
int find(int u)
{
    const int parent=fa[u];
    if(parent==u)
        return u;
    const int root=find(parent);
    a[u]+=a[parent];
    fa[u]=root;
    return root;
}
// Reads T commands over 30000 elements that start as singleton sets.
//   'M' x y : attaches the set containing x under the set containing y; the
//             absorbed root gets the target set's current size as its offset.
//   other   : prints -1 if x and y are in different sets, otherwise
//             |a[x]-a[y]|-1.
// Input reading stops early if a command cannot be parsed.
int main()
{
    if(scanf("%d",&T)!=1)
        return 0;
    for(int i=1;i<=30000;i++)
    {
        fa[i]=i;
        a[i]=0;
        cnt[i]=1;
    }
    while(T--)
    {
        char opt;
        int x,y;
        // Stop on malformed/truncated input instead of using uninitialised values.
        if(scanf(" %c %d %d",&opt,&x,&y)!=3)
            break;
        // find() also refreshes a[x] and a[y] to be offsets relative to the roots.
        const int fx=find(x);
        const int fy=find(y);
        if(opt=='M')
        {
            // Merging a set into itself would add cnt[fy] to the root's own
            // offset and then zero the counter, corrupting the structure.
            if(fx!=fy)
            {
                a[fx]+=cnt[fy];
                fa[fx]=fy;
                cnt[fy]+=cnt[fx];
                cnt[fx]=0;
            }
        }
        else if(fx!=fy)
            printf("-1\n");
        else
            printf("%d\n",ab(a[x]-a[y])-1);
    }
    return 0;
}

奇偶游戏

仍用边带权并查集。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<map>
#include<string>
// Debug helper that prints a separator line; not used in this solution.
#define debug printf("\n-------------\n")
using namespace std;
typedef long long ll;
// General-purpose constants; none of them is referenced in this solution.
const int INF1=(~0u>>2);
const ll INF2=(~0ull>>2),P1=10007,P2=998244353,P3=1000000007;
// fa: parent of every node seen so far, keyed by the node's value.
// d:  0/1 offset of a node relative to fa[node]; find() folds it along the path with XOR.
map<int,int> fa,d;
// n is read from the input but not used afterwards; m is the number of statements.
int n,m;
int find(int u)
{
    if(fa.count(u)==0||fa[u]==u)
    {
        d[u]=0;
        return fa[u]=u;
    }
    else
    {
        int f=find(fa[u]);
        d[u]^=d[fa[u]];
        return fa[u]=f;
    }
}
// Reads m statements "u v odd|even" and prints how many leading statements
// are mutually consistent: i-1 at the first contradicting statement i, or m if
// none contradicts.
int main()
{
    scanf("%d %d",&n,&m);
    for(int i=1;i<=m;i++)
    {
        int u,v;
        string s;
        cin>>u>>v>>s;
        const bool f=(s=="odd");
        // The statement is stored as a relation between nodes u-1 and v
        // (presumably prefix endpoints of the range [u, v] -- confirm against the task).
        u--;
        const int fu=find(u);
        const int fv=find(v);
        if(fu==fv)
        {
            // Parenthesised on purpose: `!=` binds tighter than `^`, so the
            // unparenthesised form only worked because all operands are 0/1.
            if((d[u]^d[v])!=f)
            {
                printf("%d",i-1);
                return 0;
            }
            continue;
        }
        // Choose the offset of the absorbed root so that d[u]^d[v] equals f afterwards.
        d[fu]=d[u]^d[v]^f;
        fa[fu]=fv;
    }
    printf("%d",m);
    return 0;
}

食物链

种类并查集（扩展域并查集）。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
// Shorthand for long long; not used in this solution.
#define ll long long
using namespace std;
// N, K: the two integers read first by main(); ids above N are rejected and K statements are processed.
// sum:  number of rejected statements (printed at the end).
// a[]:  union-find parent array; main() initialises indices 1..3*N.
int N,K,sum=0,a[300010];
// Returns the root of x in the parent array a[], compressing the path.
int f(int x)
{
    if(a[x]!=x)
        a[x]=f(a[x]);
    return a[x];
}
// Processes K statements "kind x y" with three nodes per id (x, x+N, x+2N)
// and prints how many statements were rejected.
// A statement is rejected when an id exceeds N or when it contradicts the
// relations already merged; otherwise its three node pairs are united.
int main()
{
    scanf("%d %d",&N,&K);
    for(int node=1;node<=3*N;node++)
        a[node]=node;
    while(K--)
    {
        int kind,x,y;
        scanf("%d %d %d",&kind,&x,&y);
        if(x>N||y>N)
        {
            sum++;
            continue;
        }
        if(kind==1)
        {
            // Kind 1: reject if x and y are already linked across shifted domains.
            if(f(x+N)==f(y)||f(y+N)==f(x))
            {
                sum++;
                continue;
            }
            // Unite x and y within each of the three domains.
            for(int shift=0;shift<3;shift++)
            {
                const int rx=f(x+shift*N);
                const int ry=f(y+shift*N);
                a[rx]=ry;
            }
            continue;
        }
        // Any other kind: reject if x and y share a set, or x is linked to y+N.
        if(f(x)==f(y)||f(x)==f(y+N))
        {
            sum++;
            continue;
        }
        // Unite x's domains with y's domains rotated by one step.
        int rx=f(x);
        int ry=f(y+2*N);
        a[rx]=ry;
        rx=f(x+N);
        ry=f(y);
        a[rx]=ry;
        rx=f(x+2*N);
        ry=f(y+N);
        a[rx]=ry;
    }
    printf("%d",sum);
    return 0;
}

关押罪犯

按怨气值从大到小排序后，用并查集维护“敌人的敌人是朋友”：这里用 e[] 记录每个人的一个敌人，效果与扩展域等价。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
// One input triple: u and v are the two ids read first, w is the third value
// (the key the records are sorted by, in descending order).
struct node
{
    int u;
    int v;
    int w;
};
// Storage for the triples; main() fills indices 1..M.
node c[100010];
// N: number of ids (a[1..N] is initialised); M: number of triples read into c[].
// a[]: union-find parent array.  e[x]: the first opponent recorded for x, 0 if none yet.
int N,M,a[100010],e[100010];
// Returns the root of x in the parent array a[], compressing the path.
int f(int x)
{
    if(a[x]!=x)
        a[x]=f(a[x]);
    return a[x];
}
bool cmp(node _,node __){return _.w>__.w;}
int main()
{
    scanf("%d %d",&N,&M);
    for(int i=1;i<=M;i++)
        scanf("%d %d %d",&c[i].u,&c[i].v,&c[i].w);
    for(int i=1;i<=N;i++)
        a[i]=i;
    sort(c+1,c+M+1,cmp);
    for(int i=1;i<=M+1;i++)
    {
        if(f(c[i].u)==f(c[i].v))
        {
            printf("%d",c[i].w);
            return 0;
        }
        if(!e[c[i].u])
            e[c[i].u]=c[i].v;
        else
            a[f(a[e[c[i].u]])]=f(a[c[i].v]);
        if(!e[c[i].v])
            e[c[i].v]=c[i].u;
        else
            a[f(a[e[c[i].v]])]=f(a[c[i].u]);
    }
}

石头剪子布

仍然是扩展域，枚举裁判。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
// Debug helper that prints a separator line; not used in this solution.
#define debug printf("\n-------------\n")
using namespace std;
typedef long long ll;
// General-purpose constants; none of them is referenced in this solution.
const int INF1=(~0u>>2);
const ll INF2=(~0ull>>2),P1=10007,P2=998244353,P3=1000000007;
// n, m:  the two integers read per test case (ids run 1..n after shifting, m input lines).
// fa[]:  union-find parent array over 3*n nodes (u, u+n, u+2n per id).
// q[i]:  i-th input line as {operator character, left id + 1, right id + 1}.
// mx:    largest line index at which some candidate was found inconsistent.
// cnt:   number of candidates whose remaining lines stay consistent.
// ans:   the last such candidate (1-based).
int n,m,fa[2010],q[2010][3],mx,cnt,ans;
// Returns the root of u and points every node on the walked path at it.
int find(int u)
{
    int root=u;
    while(fa[root]!=root)
        root=fa[root];
    while(fa[u]!=root)
    {
        const int next=fa[u];
        fa[u]=root;
        u=next;
    }
    return root;
}
void merge(int u,int v)
{
    int fu=find(u),fv=find(v);
    if(fu!=fv)
        fa[fu]=fv;
}
bool check(int u)
{
    int f1=find(u),f2=find(u+n),f3=find(u+n*2);
    return (f1==f2||f1==f3||f2==f3);
}
// For every test case: reads m lines "a<op>b", then tries every id as the one
// whose lines are ignored. A candidate survives if the remaining lines never
// make check() true. Prints "Impossible" (no survivor), the unique survivor
// with the largest line index at which another candidate was rejected, or
// "Can not determine" (several survivors).
int main()
{
    // Require both numbers; `!= EOF` would keep looping on a partial match.
    while(scanf("%d %d",&n,&m)==2)
    {
        for(int i=1;i<=m;i++)
        {
            char opt;
            // Read the operator into a real char: handing an int* to %c is
            // undefined behaviour and only appeared to work on little-endian
            // machines. The blanks in the format also accept "a < b".
            if(scanf("%d %c %d",&q[i][1],&opt,&q[i][2])!=3)
                return 0;
            q[i][0]=opt;
            // Shift ids to 1-based so that u, u+n, u+2n are distinct nodes.
            q[i][1]++;
            q[i][2]++;
        }
        cnt=0;
        ans=0;
        mx=0;
        for(int i=1;i<=n;i++)
        {
            bool g=true;
            for(int j=1;j<=n*3;j++)
                fa[j]=j;
            for(int j=1;j<=m;j++)
            {
                char opt=q[j][0];
                int u=q[j][1],v=q[j][2];
                // Lines that involve the current candidate are ignored.
                if(u==i||v==i)
                    continue;
                if(opt=='=')
                {
                    merge(u,v);
                    merge(u+n,v+n);
                    merge(u+n*2,v+n*2);
                }
                else
                {
                    // Normalise '<' to the '>' orientation.
                    if(opt=='<')
                        swap(u,v);
                    merge(u,v+n);
                    merge(u+n,v+n*2);
                    merge(u+n*2,v);
                }
                if(check(u)||check(v))
                {
                    // Candidate i is rejected at line j.
                    mx=max(mx,j);
                    g=false;
                    break;
                }
            }
            if(g)
            {
                cnt++;
                ans=i;
            }
        }
        switch(cnt)
        {
            case 0:printf("Impossible\n");break;
            case 1:printf("Player %d can be determined to be the judge after %d lines\n",ans-1,mx);break;
            default:printf("Can not determine\n");break;
        }
    }
    return 0;
}

真正的骗子

扩展域并查集，再用背包 DP 判断方案是否唯一。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<vector>
// Debug helper that prints a separator line; not used in this solution.
#define debug printf("\n-------------\n")
using namespace std;
typedef long long ll;
// General-purpose constants; none of them is referenced in this solution.
const int INF1=(~0u>>2);
const ll INF2=(~0ull>>2),P1=10007,P2=998244353,P3=1000000007;
// adj[r]: ids in 1..n whose own node has root r.
vector<int> adj[1210];
// vis:      first marks roots already handled by the DP, later marks the ids to print.
// dp[i][j]: OR of dp[prev][j-w1] and dp[prev][j-w2] for the component handled at i
//           (intended as a "sum j is reachable" table).
bool vis[1210],dp[1210][1210];
// pre[i]:    index of the component handled immediately before i (0 for the first one).
// cnt[i][j]: sum of cnt[prev][j-w1] and cnt[prev][j-w2], i.e. a count of ways to reach j.
short pre[1210],cnt[1210][1210];
// m: number of statements; q1, q2: the two counts read with it; n = q1 + q2.
// fa[]: union-find parent array over 2*n nodes (i and i+n per id).
int n,m,q1,q2,fa[1210];
// Returns the root of u in fa[], compressing the path on the way back.
int find(int u)
{
    if(fa[u]!=u)
        fa[u]=find(fa[u]);
    return fa[u];
}
void merge(int u, int v)
{
    int fu=find(u),fv=find(v);
    if(fu!=fv)
        fa[fu]=fv;
}
int main()
{
    while(1)
    {
        scanf("%d %d %d",&m,&q1,&q2);
        n=q1+q2;
        if(!m&&!n)break;
        memset(dp,false,sizeof(dp));
        memset(pre,0,sizeof(pre));
        memset(cnt,0,sizeof(cnt));
        memset(vis,false,sizeof(vis));
        for(int i=1;i<=2*n;i++)
            adj[i].clear();
        for(int i=1;i<=2*n;i++)
            fa[i]=i;
        for(int i=1;i<=m;i++)
        {
            int u,v;
            string op;
            cin>>u>>v>>op;
            if(u==v)
                continue;
            if(op=="yes")
            {
                merge(u,v);
                merge(u+n,v+n);
            }
            else
            {
                merge(u,v+n);
                merge(u+n,v);
            }
        }
        for(int i=1;i<=n;i++)
            adj[find(i)].push_back(i);
        int mx=0;
        dp[0][0]=false;
        cnt[0][0]=1;
        for(int i=1;i<=n;i++)
        {
            if(!vis[find(i)])
            {
                int f1=find(i),f2=find(i+n),w1=adj[f1].size(),w2=adj[f2].size();
                vis[find(i)]=vis[find(i+n)]=true;
                for(int j=q1+q2;j>=0;j--){
                    if(j>=w1)
                    {
                        dp[i][j]|=dp[mx][j-w1];
                        cnt[i][j]+=cnt[mx][j-w1];
                    }
                    if(j>=w2)
                    {
                        dp[i][j]|=dp[mx][j-w2];
                        cnt[i][j]+=cnt[mx][j-w2];
                    }
                }
                pre[i]=mx;
                mx=i;
            }
        }
        if(cnt[mx][q1]!=1)
        {
            puts("no");
            continue;
        }
        memset(vis,false,sizeof(vis));
        int cur=mx,j=q1;
        while(j)
        {
            int f1=find(cur),f2=find(cur+n),w1=adj[f1].size(),w2=adj[f2].size();
            if(dp[cur][j]==dp[pre[cur]][j-w1])
            {
                for(int k=0;k<w1;k++)
                    vis[adj[f1][k]]=true;
                j-=w1;
            }
            else if(dp[cur][j]==dp[pre[cur]][j-w2])
            {
                for(int k=0;k<w2;k++)
                    vis[adj[f2][k]]=true;
                j-=w2;
            }
            cur=pre[cur];
        }
        for(int i=1;i<=n;i++)
            if(vis[i])
                printf("%d\n",i);
        printf("end\n");
    }
    return 0;
}

Prefix Enlightenment

维护集合大小+扩展域。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<vector>
// Debug helper that prints a separator line; not used in this solution.
#define debug printf("\n-------------\n")
using namespace std;
typedef long long ll;
// INF1 is used by main() as the "forbidden" size; the other constants are unused here.
const int INF1=(~0u>>2);
const ll INF2=(~0ull>>2),P1=10007,P2=998244353,P3=1000000007;
// adj[x]: indices of the input groups that list position x.
vector<int> adj[600010];
// vis[g]: set once group g has been included in the running answer.
bool vis[600010];
// s: the string read from the input, stored from index 1.
char s[600010];
// n, m: the two integers read first (positions 1..n, groups 1..m).
// fa[]: union-find parent array over 2*m nodes (g and g+m per group).
// ans: running answer, printed once per position.
int n,m,fa[600010],ans=0;
// sz[r]: weight of the set rooted at r; nodes 1..m start at 1, nodes m+1..2m at 0.
ll sz[600010];
// Returns the root of u and points every node on the walked path at it.
int find(int u)
{
    int root=u;
    while(fa[root]!=root)
        root=fa[root];
    while(fa[u]!=root)
    {
        const int next=fa[u];
        fa[u]=root;
        u=next;
    }
    return root;
}
// True when u and v currently belong to the same set.
bool check(int u,int v)
{
    const int rootU=find(u);
    const int rootV=find(v);
    return rootU==rootV;
}
void merge(int u,int v)
{
    int fu=find(u),fv=find(v);
    if(fu!=fv)
    {
        sz[fv]+=sz[fu];
        fa[fu]=fv;
    }
}
int main()
{
    scanf("%d %d",&n,&m);
    scanf("%s",s+1);
    for(int i=1;i<=m*2;i++)
    {
        fa[i]=i;
        if(i<=m)
            sz[i]=1;
    }
    for(int i=1;i<=m;i++)
    {
        int c;
        scanf("%d",&c);
        while(c--)
        {
            int x;
            scanf("%d",&x);
            adj[x].push_back(i);
        }
    }
    for(int i=1;i<=n;i++)
    {
        if(adj[i].size()==1)
        {
            int x=adj[i][0];
            if(vis[x])
                ans-=min(sz[find(x)],sz[find(x+m)]);
            if(s[i]=='0')
                sz[find(x+m)]=INF1;
            else
                sz[find(x)]=INF1;
            ans+=min(sz[find(x)],sz[find(x+m)]);
            vis[x]=true;
        }
        else if(adj[i].size()==2)
        {
            int x=adj[i][0],y=adj[i][1];
            if(check(x,y)||check(x,y+m))
            {
                printf("%d\n",ans);
                continue;
            }
            if(vis[x])
                ans-=min(sz[find(x)],sz[find(x+m)]);
            if(vis[y])
                ans-=min(sz[find(y)],sz[find(y+m)]);
            if(s[i]=='1')
            {
                merge(x,y);
                merge(x+m,y+m);
            }
            else
            {
                merge(x,y+m);
                merge(x+m,y);
            }
            ans+=min(sz[find(x)],sz[find(x+m)]);
            vis[x]=vis[y]=true;
        }
        printf("%d\n",ans);
    }
    return 0;
}