题解:AT_abc396_e [ABC396E] Min of Restricted Sum

· · 题解

题解:E - Min of Restricted Sum

提供一个非常容易想的思路,但是赛时没调出来。

一个显然的操作就是把每个数二进制拆分一下,对于每个数位求出其对答案的贡献。

我们设一个数 x 的第 i 个二进制位为 f_i(x)(显然 f_i(x) \in \{0,1\}),只考虑所有数第 i 个二进制位可以得到:对于所有的 j \in [1,N] 满足 f_i(A_{X_j}) \oplus f_i(A_{Y_j}) = f_i(Z_j)

我们根据以上分析建图,分两种情况:

  1. f_i(Z_j) = 0,则 f_i(A_{X_j})f_i(A_{Y_j}) 相同,所以只要使用并查集把 X_jY_j 合并到一联通块中;
  2. f_i(Z_j) = 1,则 f_i(A_{X_j})f_i(A_{Y_j}) 不同,就把 X_j 所在的联通块和 Y_j 所在的联通块连边(注意我是把联通块当成点建边的),代表它们所在的两个联通块在数值上不同。

在新建的图上染一下色,由于要使 \sum_{j=1}^N A_j 最小化,即 \sum_{j=1}^N f_i(A_j) 最小化,肯定是染尽量多的 0 是较优的,所以就尝试染起始的联通块为 01,找出较优的方案更新;而如果图中染色出现相邻两个联通块染的色是一样的话就说明无解。

实现细节:

我的代码:

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5;
int n, m, x[N], y[N], z[N];

// 并查集封装版
struct dsu {
    int fa[N], siz[N], n;
    void init(int n) {for (int i = 1; i <= n; i++) fa[i] = i, siz[i] = 1; }
    int find(int x) { return x == fa[x] ? x : find(fa[x]); }
    void unite(int x, int y) {
        x = find(x), y = find(y);
        if (x == y) return;
        if (siz[x] > siz[y]) swap(x, y);
        siz[y] += siz[x], fa[x] = y;
    }
} same[35]; 
// 把边先存下来以便之后操作
vector<pair<int, int>> diff[36];
// 在图上跑染色
vector<int> G[N];
int col[N], res;
vector<int> sta; // 记录这一次有哪些点被染的色,如果不优需要把这些点染的色取反
void dfs(int u, int c, int j, int b) {
    col[u] = c; sta.push_back(u);
    if (c) res += same[j].siz[u];
    for (int v : G[u]) {
        if (col[v] == -1) dfs(v, !c, j, b);
        else if (col[v] == c) {
            cout << -1;
            exit(0);
        }
    }
}
int a[N];
int main() {
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= m; i++)
        cin >> x[i] >> y[i] >> z[i];
    for (int j = 0; j <= 30; j++) same[j].init(n);
    for (int i = 1; i <= m; i++) {
        for (int j = 0; j <= 30; j++) {
            int cur = z[i] >> j & 1;
            if (cur == 0) same[j].unite(x[i], y[i]);
            else diff[j].push_back({x[i], y[i]});
        }
    }
    for (int j = 0; j <= 30; j++) {
        memset(G, 0, sizeof(G));
        for (auto [u, v] : diff[j]) {
            u = same[j].find(u), v = same[j].find(v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        int ans = 0;
        memset(col, -1, sizeof(col));
        for (int i = 1; i <= n; i++) {
            if (same[j].fa[i] != i) continue;
            if (col[i] == -1) {
                sta.clear();
                res = 0, dfs(i, 0, j, i);
                int siz = 0;
                for (int i : sta) siz += same[j].siz[i];
                if (res < siz - res) { // i染0更优
                    ans += res;
                } else { // i染1更优
                    for (int x : sta) col[x] = !col[x];
                    ans += siz - res;
                }
            }
        }
        for (int i = 1; i <= n; i++)
            a[i] += col[same[j].find(i)] << j; // 统计答案
    }
    for (int i = 1; i <= n; i++)
        cout << a[i] << " \n"[i == n];
    return 0;
}

由于是赛时代码改过的,实现免不了有些难看,常数也有点大。我尽量改的通俗易懂一些,如果还有不懂的可以找我解答。