CF1654G Snowy Mountain 题解

· · 题解

洛谷传送门:Snowy Mountain

贡献一个 O(n \log n) 解法。

初始可以用 BFS 求出每个点的高度 h_u

接下来我们来考虑如何使得操作数最大化,这点其它题解讲的很明白,从点 u 出发的答案即为 2h_u-h_v,其中 v 是从 u 出发能到达的高度最小的具有相同高度邻点的点。

考虑使用点分治计算每个点 u 所对应的 h_v,设当前分治连通块的重心为 x,连通块大小为 s,先用 DFS O(s) 计算出从 x 出发到达连通块内每个点 v 所需的最小初始动能,具体做法是维护变量 sum,高度下降则 -1,不变则 +1,到 v 点时所需的最小动能即为过程中 sum 的最大值。

使用同样的方法,我们也可以用第二遍 DFS 处理出从连通块内每个点 u 是否可以到达 x,以及到达时的动能。此时在 O(s \log s) 时间内处理每个连通块的方式就比较明朗了:将所有能够到达的、具有相同高度邻点的 v 按照到达它所需的最小动能排序,然后按照 h_v 做一个前缀min,在第二遍 DFS 到某个点 u 时直接在这个有序序列上二分查找其所对应的最小的 h_v 即可,因为是取min而不是求和,所以可以不用担心 uv 来自同一棵子树的情况。

这样复杂度是 O(n \log^2 n) 的,然后我们还有很好的一个性质没有用到:每次变量 sum 只会 +1-1

这意味着变量 sum 的值是 O(s) 的,这就允许我们可以不用排序,而是直接在值域上操作,这样处理单个分治连通块的复杂度就降到了 O(s),总复杂度 O(n \log n)

代码:

//CF1654G Snowy Mountain
//O(nlogn)解法
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)
#define eb emplace_back
#define pb push_back
#define mp make_pair
using namespace std;

typedef long long ll;
const int N = 2e5+5;
const int Mod = 998244353;

int n, h[N], ans[N];
int siz[N];
bool b[N], vis[N];

vector<int> T[N];
queue<int> Q;

void dfs_centre(int x, int fa, int sizz, int &c) {
    siz[x] = 1;
    int mx = 0;
    for (auto son : T[x]) {
        if (son == fa || vis[son]) continue;
        dfs_centre(son, x, sizz, c);
        siz[x] += siz[son];
        mx = max(mx, siz[son]);
    }
    if (mx <= sizz / 2 && siz[x] >= sizz / 2) c = x;
}

int L, R, mn[N * 2];

void dfs(int x, int fa, int mx, int sum, bool t) {
    mx = max(mx, t ? sum : -sum);
    if (t) {
        if (sum >= mx && sum + n >= L)
            ans[x] = min(ans[x], mn[min(R, sum + n)]);
    } else if (b[x]) {
        while (R < mx + n) mn[++R] = n;
        while (L > mx + n) mn[--L] = n;
        mn[mx + n] = min(mn[mx + n], h[x]);
    }

    for (auto son : T[x]) {
        if (son == fa || vis[son]) continue;
        if (h[son] == h[x]) dfs(son, x, mx, sum - 1, t);
        else if (t ^ (h[x] > h[son])) dfs(son, x, mx, sum + 1, t);
    }
}

inline void calc(int x) {
    L = R = mn[n] = n;
    dfs(x, x, 0, 0, 0);
    for (int i = L + 1; i <= R; i++) mn[i] = min(mn[i], mn[i - 1]);
    dfs(x, x, 0, 0, 1);
}

void solve(int x, int sizz) {
    dfs_centre(x, x, sizz, x);
    vis[x] = 1;
    calc(x);
    for (auto son : T[x]) {
        if (vis[son]) continue;
        solve(son, siz[son] > siz[x] ? sizz - siz[x] : siz[son]);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &b[i]);
        h[i] = -1;
        if (b[i]) {
            h[i] = 0;
            Q.push(i);
        }
    }

    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        T[u].pb(v), T[v].pb(u);
    }

    while (!Q.empty()) {
        int top = Q.front();
        Q.pop();
        for (auto v : T[top]) if (!~h[v]) {
            h[v] = h[top] + 1;
            Q.push(v);
        }
    }

    for (int i = 1; i <= n; i++) {
        ans[i] = h[i];
        b[i] = 0;
        for (auto j : T[i]) if (h[i] == h[j])
            { b[i] = 1; break; }
    }

    solve(1, n);

    for (int i = 1; i <= n; i++) printf("%d ", 2 * h[i] - ans[i]);
    return 0;
}