扩展域并查集

· · 个人记录

扩展域并查集

扩展域并查集的精髓就是把一个节点拆成多个进行合并。

例题:P2024 [NOI2001] 食物链

思路

对于本题而言,可以将一个节点拆成三个:自己,自己的食物,和自己的天敌。

对于操作1,我们首先需要判断xy是否合法。如果它们直间直接联系,那么就判断y是否是x的天敌或者食物;如果xy中间没有直接连边,那么根据题意,只有三类动物。换句话说,如果x的食物是y的天敌,那么x的天敌就是y。但这样显然是不符合条件的,所以我们只需要判断x的天敌与y的食物以及x的食物与y的天敌是否相同即可。如果相同,说明不符合条件,更新答案。反之,就要合并集合。由于xy是同类,所以食物和天敌也一样,直接把三对点所在集合合并即可。

对于操作2,我们还是需要判断是否合法,思路跟操作1基本相同,不再赘述。

Code

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<cstring>
#define ll long long
using namespace std;
const int N=50005;
int n,k,ans;
int fa[N*3];//1~n代表self,n+1~n*2代表eat,n*2+1~n*3代表enemy 
int get(int x){
    if(fa[x]==x) return x;
    else return fa[x]=get(fa[x]);
}
void merge(int x,int y){
    fa[get(x)]=get(y);
}
void merge1(int x,int y){
    merge(x,y);
    merge(x+n,y+n);
    merge(x+2*n,y+2*n);
}
void merge2(int x,int y){
    merge(x,y+2*n);
    merge(x+n,y);
    merge(x+2*n,y+n);
}
int main()
{
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n*3;i++){
        fa[i]=i;
    }
    for(int i=1,op,x,y;i<=k;i++){
        scanf("%d%d%d",&op,&x,&y);
        if(x>n||y>n){
//          printf("%d\n",i);
            ans++;
            continue;
        }
        if(op==1){
            if(get(x+n)==get(y)||get(x+2*n)==get(y)||get(x+n)==get(y+2*n)||get(x+2*n)==get(y+n)){
//              printf("%d\n",i);
                ans++;
            }
            else{
                merge1(x,y);
            }
        }
        else{
            if(get(x+n)==get(y+n*2)||get(x+2*n)==get(y)||get(y+n)==get(x)||get(x)==get(y)){
//              printf("%d\n",i);
                ans++;
            }
            else{
                merge2(x,y);
            }
        }
    }
    printf("%d\n",ans);
    return 0;
}

例题:P5937 [CEOI1999]Parity Game

思路

前缀和的思路真的是非常妙啊

首先对于区间[l,r],将它们的前缀和作为一个元素。那么如果说[l,r]有奇数个1,那么sum_{l-1}sum_r的前缀和奇偶性必定不相等。那么我们可以使用扩展域并查集,将一个前缀和节点拆成两个,一个奇数一个偶数。

对于odd,说明奇偶性不相同。先判一下奇偶性是否相同,相同则不符合条件直接输出答案。反之则直接合并奇数和奇数,偶数和偶数。

对于even同理,然后只需要再离散化一下,这题就解决了。

Code

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<cstring>
#define ll long long
using namespace std;
const int N=50005;
int n,m,sum,tot; 
int fa[2*N],li[2*N];
struct Node{
    int x,y,op;
}a[2*N];
int get(int x){
    if(x==fa[x]) return x;
    else return fa[x]=get(fa[x]);
}
void merge(int x,int y){
    fa[get(x)]=get(y);
}
void lisan(){
    sort(li+1,li+1+tot);
    sum=unique(li+1,li+1+tot)-li;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++){
        string s;
        scanf("%d%d",&a[i].x,&a[i].y);
        cin>>s;
        if(s=="odd"){
            a[i].op=1;
        }
        else{
            a[i].op=2;
        }
        li[++tot]=a[i].x;
        li[++tot]=a[i].y;
    }
    lisan();
    for(int i=1;i<=2*sum+2;i++){
        fa[i]=i;
    }
    for(int i=1;i<=m;i++){
        a[i].x=lower_bound(li+1,li+1+sum,a[i].x-1)-li;
        a[i].y=lower_bound(li+1,li+1+sum,a[i].y)-li;
    }
    for(int i=1;i<=m;i++){
        if(a[i].op==1){
            if(get(a[i].x)==get(a[i].y)){
                printf("%d\n",i-1);
                return 0;
            }
            merge(a[i].x,a[i].y+sum+1);
            merge(a[i].x+sum+1,a[i].y);
        }
        else{
            if(get(a[i].x)==get(a[i].y+sum+1)){
                printf("%d\n",i-1);
                return 0;
            }
            merge(a[i].x,a[i].y);
            merge(a[i].x+sum+1,a[i].y+sum+1);
        }
    }
    printf("%d\n",m);
    return 0;
}

细节

$2.$初始化别忘了循环到$2n+1$,而不是$n+1$。