题解:AT_abc396_e [ABC396E] Min of Restricted Sum
题解:E - Min of Restricted Sum
提供一个非常容易想的思路,但是赛时没调出来。
一个显然的操作就是把每个数二进制拆分一下,对于每个数位求出其对答案的贡献。
我们设一个数
我们根据以上分析建图,分两种情况:
- 若
f_i(Z_j) = 0 ,则f_i(A_{X_j}) 与f_i(A_{Y_j}) 相同,所以只要使用并查集把X_j 和Y_j 合并到一联通块中; - 若
f_i(Z_j) = 1 ,则f_i(A_{X_j}) 与f_i(A_{Y_j}) 不同,就把X_j 所在的联通块和Y_j 所在的联通块连边(注意我是把联通块当成点建边的),代表它们所在的两个联通块在数值上不同。
在新建的图上染一下色,由于要使
实现细节:
- 由于新建的图不一定联通,记得要跑多遍
dfs(就因为这个没在赛时调出来); - 我们是用联通块缩成点建的边,但其实显然只要把并查集中每个联通块的根节点连边就行了,此外记得计算染
0 的个数时要乘上联通块大小; - 至于你问我联通块大小怎么求,我是在并查集中顺带计算的大小,此时不适合用路径压缩,所以就用的启发式合并优化并查集的时间复杂度;
- 记得每一个二进制位都要跑一遍染色,记得清空
vector和数组!
我的代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
int n, m, x[N], y[N], z[N];
// 并查集封装版
struct dsu {
int fa[N], siz[N], n;
void init(int n) {for (int i = 1; i <= n; i++) fa[i] = i, siz[i] = 1; }
int find(int x) { return x == fa[x] ? x : find(fa[x]); }
void unite(int x, int y) {
x = find(x), y = find(y);
if (x == y) return;
if (siz[x] > siz[y]) swap(x, y);
siz[y] += siz[x], fa[x] = y;
}
} same[35];
// 把边先存下来以便之后操作
vector<pair<int, int>> diff[36];
// 在图上跑染色
vector<int> G[N];
int col[N], res;
vector<int> sta; // 记录这一次有哪些点被染的色,如果不优需要把这些点染的色取反
void dfs(int u, int c, int j, int b) {
col[u] = c; sta.push_back(u);
if (c) res += same[j].siz[u];
for (int v : G[u]) {
if (col[v] == -1) dfs(v, !c, j, b);
else if (col[v] == c) {
cout << -1;
exit(0);
}
}
}
int a[N];
int main() {
ios::sync_with_stdio(false), cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= m; i++)
cin >> x[i] >> y[i] >> z[i];
for (int j = 0; j <= 30; j++) same[j].init(n);
for (int i = 1; i <= m; i++) {
for (int j = 0; j <= 30; j++) {
int cur = z[i] >> j & 1;
if (cur == 0) same[j].unite(x[i], y[i]);
else diff[j].push_back({x[i], y[i]});
}
}
for (int j = 0; j <= 30; j++) {
memset(G, 0, sizeof(G));
for (auto [u, v] : diff[j]) {
u = same[j].find(u), v = same[j].find(v);
G[u].push_back(v);
G[v].push_back(u);
}
int ans = 0;
memset(col, -1, sizeof(col));
for (int i = 1; i <= n; i++) {
if (same[j].fa[i] != i) continue;
if (col[i] == -1) {
sta.clear();
res = 0, dfs(i, 0, j, i);
int siz = 0;
for (int i : sta) siz += same[j].siz[i];
if (res < siz - res) { // i染0更优
ans += res;
} else { // i染1更优
for (int x : sta) col[x] = !col[x];
ans += siz - res;
}
}
}
for (int i = 1; i <= n; i++)
a[i] += col[same[j].find(i)] << j; // 统计答案
}
for (int i = 1; i <= n; i++)
cout << a[i] << " \n"[i == n];
return 0;
}
由于是赛时代码改过的,实现免不了有些难看,常数也有点大。我尽量改的通俗易懂一些,如果还有不懂的可以找我解答。