CSP-S2025 T2

· · 个人记录

赛事

容易想到的一种做法。
先跑一遍原图最小生成树,只存树上的边。
枚举村庄城市化方案,暴力加边,重新做最小生成树。
复杂度 O(2^k \cdot n \cdot k \cdot (\log(n \cdot k) + \alpha(n)))
容易发现一种优化。
发现新的方案的答案的边一定存在于其子集的答案中(和上面做法的第一步类似)。所以我们可以随意选择两个子集(并为当前集合),暴力加这两个子集答案的边,我们发现边数减少到了 2 \cdot n 级别。
所以我们的时间复杂度到了 O(2^k \cdot 2 \cdot n \cdot (\log(2 \cdot n) + \alpha(n))).

赛后

赛后发现赛时做法并非正解,而是 80pts。感觉我赛时蠢得离谱,赛后经机房大佬yzh点拨,发现了一种简单的去掉 \log 的方法。

只需要在做之前先对整个图(城市间和城市与村庄间的所有边)预排序,这样每次求答案的时候便不需要再排序了(因为没有新加边)。

这样做时间复杂度到了 O(m \log m + 2 ^ k \cdot n \cdot \alpha(n + k)) 直接就非常优啊。

最终代码

#include<bits/stdc++.h>
using namespace std;
#define N 10004
int n, m, k, f[N + 10], now[(N << 1) + 10], book[15][N];
int mem[1025][N + 10], c[15];
long long ans;
struct node { int u, v, w; } mp[N * 10 + 1000000];
inline bool cmp(node x, node y) { return x.w > y.w; }
inline int find(int x) { return f[x] ^ x ? f[x] = find(f[x]) : x; }
inline void read(int &ret) {
    ret = 0;
    char c = getchar_unlocked();
    while(c > '9' || c < '0') 
        c = getchar_unlocked();
    while(c >= '0' && c <= '9') 
        ret = (ret << 3) + (ret << 1) + (c ^ 48), c = getchar_unlocked();
}
inline void check(int x) {
    if(mem[x][0]) return;
    int cnt = 0;
    for(int i = k - 1; ~i; --i) {
        if(x & (1 << i)) {
            check(x - (1 << i));
            int l = 1, r = 1;
            while(l <= mem[x - (1 << i)][0] && r <= n) {
                while(r <= n && mem[x - (1 << i)][l] < book[i][r]) now[++cnt] = book[i][r++];
                now[++cnt] = mem[x - (1 << i)][l++];
            }
            while(l <= mem[x - (1 << i)][0]) now[++cnt] = mem[x - (1 << i)][l++];
            while(r <= n) now[++cnt] = book[i][r++];
            break;
        }
    }
    long long ret = 0;
    for(int i = k - 1; ~i; --i) 
        if(x & (1 << i)) ret += c[i];
    for(int i = n + k; i; --i) f[i] = i;
    for(int i = 1; i <= cnt; ++i) {
        int uu = find(mp[now[i]].u), vv = find(mp[now[i]].v);
        if(uu ^ vv) {
            ret += mp[now[i]].w, f[uu] = vv, mem[x][++mem[x][0]] = now[i];
        }
    }
    ans = ans < ret ? ans : ret;
}
int main() {
    read(n), read(m), read(k);
    for(int i = m; i; --i) read(mp[i].u), read(mp[i].v), read(mp[i].w);
    for(int i = 0; i < k; ++i) {
        read(c[i]);
        for(int j = 1; j <= n; ++j) {
            int aa;
            read(aa);
            mp[m + i * n + j] = {j, n + i + 1, aa};
        }
    }
    sort(mp + 1, mp + m + n * k + 1, cmp);
    for(int i = n + k; i; --i) f[i] = i;
    for(int i = m + n * k; i; --i) {
        int uu = find(mp[i].u), vv = find(mp[i].v);
        if((uu ^ vv) && vv <= n) 
            ans += mp[i].w, f[uu] = vv, mem[0][++mem[0][0]] = i;
        else if(vv > n)
            book[vv - n - 1][++book[vv - n - 1][0]] = i;
    }
    for(int i = (1 << k) - 1; i; --i) check(i);
    printf("%lld", ans);
    return 0;
}