题解:P10998 【MX-J3-T3+】Tuple+

· · 题解

思路

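考虑枚举每一个给定的三元组 (a,b,c)(a<b<c),把它看作四元组中最小的三个数,再去寻找所有满足 d>c 且 (a,b,d)、(a,c,d)、(b,c,d) 都出现过的 d。这样每个合法的四元组 (a,b,c,d) 恰好在枚举到 (a,b,c) 时被统计一次。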

具体实现上,将每个三元组 (a,b,c) 按前两个数 (a,b) 分组,把 c 存入 (a,b) 对应的集合中。之后枚举每个三元组 (a,b,c):d 必须同时属于 (a,b)、(a,c)、(b,c) 对应的三个集合,因此只需遍历其中最小的那个集合,并在另外两个集合中查询,即可找到所有满足条件的 d。

代码

// Problem: T500879 【MX-J3-T3】Tuple
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/T500879?contestId=193566
// Memory Limit: 512 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include<bits/stdc++.h>

// Separator string, presumably for ad-hoc debug printing; it is not
// referenced anywhere in the code shown.
#define zibenlun "\n========================================================\n"
// Every `int` token after this line is textually replaced by `long long`,
// which is why the entry point below has to be declared `signed main()`.
#define int long long
// Shorthand for the GCC/Clang 128-bit signed integer type.
#define int_128 __int128
// Quick "reached here" marker: prints 1 to cout (expects `using namespace std`).
#define debug cout<<1;
// Lowest set bit of x, e.g. lowbit(12) == 4. The argument is parenthesised so
// that expressions such as lowbit(i + 1) expand to ((i + 1)&(-(i + 1))) rather
// than (i + 1&(-i + 1)).
#define lowbit(x) ((x)&(-(x)))
// Disables C-stdio synchronisation and stream tying for faster cin/cout.
#define faster ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
using namespace std;
// Reads the next (optionally negative) decimal integer from stdin.
// Any characters other than digits and '-' before the number are skipped.
// Returns 0 if end of input is reached before any digit is found.
inline int read() {
    int s = 0;
    bool negative = false;
    // Keep getchar()'s result in an integer, not a char, so EOF stays
    // distinguishable from a real 0xFF byte.
    int ch = getchar();
    // Skip separators, but stop at EOF instead of looping forever.
    while (ch != EOF && (ch < '0' || ch > '9') && ch != '-') ch = getchar();
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -s : s;
}
// Writes x to stdout in decimal (no trailing separator).
inline void write(int x) {
    // Work with the magnitude as unsigned long long: negating the most
    // negative signed value would overflow (undefined behaviour), whereas
    // unsigned negation is well-defined modulo 2^64.
    unsigned long long u = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        u = 0ULL - u;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char buf[24];
    int len = 0;
    do {
        buf[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u != 0);
    while (len > 0) putchar(buf[--len]);
}
// ---------------------------------------------------------------------------
// Counts quadruples a<b<c<d such that (a,b,c), (a,b,d), (a,c,d) and (b,c,d)
// all appear among the m input triples, and prints the count.
//
// For every recorded pair (x,y) we keep the set of z with (x,y,z) in the
// input. Each quadruple is found exactly once, from its smallest triple
// (a,b,c): d must lie in the intersection of the sets for (a,b), (a,c) and
// (b,c). We enumerate the smallest of the three sets and probe the other two.
// ---------------------------------------------------------------------------
struct nd{
    int a,b,c;
};

int n,m;
// Number of valid quadruples found.
int ans;
// Input triples, normalised to a<b<c, with duplicates removed.
std::vector<nd> a;
// pair_key(x,y) -> set of z such that (x,y,z) was read.
std::unordered_map<long long, std::unordered_set<int>> third;

// Encodes the ordered pair (x,y), 0 <= x,y <= n, as one 64-bit key.
long long pair_key(int x, int y) {
    return static_cast<long long>(x) * (n + 1) + y;
}

// Returns the set of z with (x,y,z) recorded, or an empty set.
// Uses find() so that querying an absent pair does not insert it.
const std::unordered_set<int>& third_of(int x, int y) {
    static const std::unordered_set<int> kEmpty;
    const auto it = third.find(pair_key(x, y));
    if (it == third.end()) return kEmpty;
    return it->second;
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m;
    a.reserve(m);
    for (int i = 1; i <= m; i++) {
        int x, y, z;
        std::cin >> x >> y >> z;
        // Sort the triple; a no-op when the input already has x<y<z.
        if (x > y) std::swap(x, y);
        if (y > z) std::swap(y, z);
        if (x > y) std::swap(x, y);
        // Keep only the first copy of a repeated triple so its quadruples are
        // not enumerated (and counted) twice.
        if (third[pair_key(x, y)].insert(z).second) a.push_back({x, y, z});
    }
    for (const nd& t : a) {
        if (t.c == n) continue; // no d > c can exist
        // The map is no longer modified, so these references stay valid.
        const std::unordered_set<int>* sets[3] = {
            &third_of(t.a, t.b), &third_of(t.a, t.c), &third_of(t.b, t.c)};
        // Enumerate the smallest candidate set and probe the other two.
        std::sort(std::begin(sets), std::end(sets),
                  [](const std::unordered_set<int>* p, const std::unordered_set<int>* q) {
                      return p->size() < q->size();
                  });
        for (int d : *sets[0]) {
            if (d > t.c && sets[1]->count(d) && sets[2]->count(d)) ans++;
        }
    }
    std::cout << ans;
    return 0;
}