启智树冲省队组Day5T3 划分

· · 个人记录

首先可以把题意转化一下:

把一棵点带权树上的 n 个点划分为 AB 两个集合,一种划分方案的代价是:来自 A 集合的代价:\sum_{i}d_i+g(A)-\sum_{x \gets y}[w_x \le w_y];来自 B 集合的代价:\sum_{x \gets y}[w_x < w_y]。其中 x \gets y 表示 xy 的祖先, g(A)={|A| \choose 2}d_i 表示 i 到根的距离。总代价为两集合的代价之和。

全部在 B 集合的情况非常好算,考虑从 B 中逐个取点加入到 A

如果点权互不相同,则容易发现取每个点的代价为 d_i - a_i - b_i,其中 a_i 表示比 i 小的祖先个数, b_i 表示比 i 大的子孙个数。它们与 A,B 集合的状态无关,所以从小到大排序后一个个加入即可。

如果有点权相同的情况,则代价应该是 d_i - a_i - b_i - c_i,其中 c_i 表示目前 A 集合中与 i 有祖先子孙关系的点中与 i 权值相同的点的个数。这时候我们不知道 c_i 是多少,并且不知道取走 d_i - a_i - b_i - c_i 最小的点的后续转移是否更优。

不过看起来如果存在祖先子孙权值相同,先取祖先好像更优。因为 d_x - a_x \le d_y - a_y(尽管可能 a_x \le a_y,但是 d_x < d_y 有压倒性优势),并且一定有 -b_x - c_x \le -b_y - c_y,于是当前祖先一定比子孙更优。并且选择祖先的后续转移也比子孙的好,因为选择祖先能“惠及”更多的点有 -c_i 的代价。因此,选择一个点的时候与其相等祖先一定全部被选,与其相等的子孙一定全没被选。据此,c_i 还可以表示为与 i 相等的祖先的个数。

每个点的 a_i,b_i,c_i,d_i 可以用树状数组算出。

关键代码:

int fa[N], dep[N];
int a[N], b[N];
void dfs1(int cur, int faa) {
    fa[cur] = faa; dep[cur] = dep[faa] + 1;
    a[cur] = query(w[cur]);
    add(w[cur], 1);
    for (register int i = head[cur]; i; i = e[i].nxt) {
        int to = e[i].to; if (to == faa)    continue;
        dfs1(to, cur);
    }
    add(w[cur], -1);
}
void dfs2(int cur, int faa) {
    add(w[cur], 1);
    b[cur] -= (query(ltot) - query(w[cur]));
    for (register int i = head[cur]; i; i = e[i].nxt) {
        int to = e[i].to; if (to == faa)    continue;
        dfs2(to, cur);
    }
    b[cur] += (query(ltot) - query(w[cur]));
}

int id[N];
inline bool cmp(const int x, const int y) {
    return dep[x] - a[x] - b[x] < dep[y] - a[y] - b[y];
}

ll ans[N];
int main() {
    read(n);
    for (register int i =1 ; i <= n; ++i)   read(w[i]), h[i] = w[i];
    lsh();
    for (register int i =1 ; i < n; ++i) {
        int u, v; read(u), read(v);
        addedge(u, v), addedge(v, u);
    }
    dep[0] = -1;
    dfs1(1, 0);
    dfs2(1, 0);
    for (register int i = 1; i <= n; ++i)   id[i] = i;
    sort(id +1 , id + 1 + n, cmp);
    ll res = 0;
    for (register int i = 1; i <= n; ++i)   res += b[i];
    ans[n] = res;
    for (register int i = 1; i <= n; ++i) {
        int p = id[i];
        res += dep[p] - a[p] - b[p];
        ans[n - i] = res + ((1ll * i * (i - 1)) >> 1);
    }
    for (register int i = 0; i <= n; ++i)
        printf("%lld\n", ans[i]);
    return 0;
}