P8074 [COCI2009-2010#7] SVEMIR 题解
Suzt_ilymtics
·
·
题解
题面,嘿嘿,题面
Discover the key or change the direction.
Solution
题目的意思就是给你 n 个点,任意两个点连边有一定代价,求它的最小生成树。
这看上去很像一个板子,一看数据范围还是算了,\mathcal O(n^2) 的 \text{prim} 算法指定过不去。
那如果把所有边建出来跑 \text{Kruskal} 呢?边数足足有 \frac{(n-1)n}{2} 条也过不去。
我们知道最小生成树只需要 n-1 条边,那这里肯定有很多边是“废边”。
考虑如何减少废边?我们从边权的条件下手。
对于两个点 a,b,它们相连的代价为 w(a,b) = \min \{|a_x-b_x|, |a_y-b_y|, |a_z-b_z|\}。
只取最小值啊。。。
并且还只是三个维度中的一个维度。
考虑三个点 a,b,c 且 a_x < b_x < c_x 且任意两点的代价都是由 x 这一维产生的,那显然连接 (a,b) 和 (b,c) 更优,且 w(a,b) + w(b,c) = w(a,c)。也就是说如果我们按照 x 排个序,可能相连的只有相邻的两个点。那我们只把这些边拿出来就好了嘛。
这样一番处理后,所用边数只有 $3n$ 条,$\text{Kruskal}$ 算法可以通过。
代码:
```cpp
#include<bits/stdc++.h>
#define int long long
#define orz cout<<"lkp AK IOI!"<<endl
using namespace std;
const int MAXN = 1e5+5;
const int INF = 1e9+7;
const int mod = 1e9+7;
struct edge { int u, v, w; }e[MAXN << 2];
struct node { int x, y, z, id; }a[MAXN];
int n, m, sum = 0;
int fa[MAXN];
int read() {
int s = 0, f = 0;
char ch = getchar();
while(!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while(isdigit(ch)) s = (s << 1) + (s << 3) + ch - '0', ch = getchar();
return f ? -s : s;
}
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
int Dis(int x, int y) { return min(abs(a[x].x - a[y].x), min(abs(a[x].y - a[y].y), abs(a[x].z - a[y].z))); }
bool cmp1(node x, node y) { return x.x < y.x; }
bool cmp2(node x, node y) { return x.y < y.y; }
bool cmp3(node x, node y) { return x.z < y.z; }
bool cmp4(edge x, edge y) { return x.w < y.w; }
signed main()
{
n = read();
for(int i = 1; i <= n; ++i) a[i].x = read(), a[i].y = read(), a[i].z = read(), a[i].id = i;
sort(a + 1, a + n + 1, cmp1);
for(int i = 1; i < n; ++i) e[++m] = (edge){a[i].id, a[i + 1].id, Dis(i, i + 1)};
sort(a + 1, a + n + 1, cmp2);
for(int i = 1; i < n; ++i) e[++m] = (edge){a[i].id, a[i + 1].id, Dis(i, i + 1)};
sort(a + 1, a + n + 1, cmp3);
for(int i = 1; i < n; ++i) e[++m] = (edge){a[i].id, a[i + 1].id, Dis(i, i + 1)};
sort(e + 1, e + m + 1, cmp4);
for(int i = 1; i <= n; ++i) fa[i] = i;
for(int i = 1, cnt = 0; i <= m; ++i) {
int uf = find(e[i].u), vf = find(e[i].v);
if(uf != vf) {
fa[uf] = vf;
sum += e[i].w;
if(++cnt == n - 1) break;
}
}
printf("%lld\n", sum);
return 0;
}
```