CF1179D Fedor Runs for President 题解

· · 题解

洛谷传送门:CF1179D Fedor Runs for President

模拟赛的时候居然自己做出来了,用的换根dp,看到题解区还没有讲到换根的,于是写篇题解介绍一下。

step 1. 问题转化:

考虑加上一条边后,新增了多少条路径。
此时原树变为了一棵基环树,将其看作是把若干子树的根顺次连接成环得到的图,当且仅当两个点分别处于不同子树的时候,其之间的路径会经过环,也就新增了一条路径。若有 k 棵子树,第 i 棵大小为 siz_i,那么路径条数为:

\begin{aligned} &\frac{n(n-1)}{2}+\frac{1}{2}\sum_{i=1}^{k}siz_i(n-siz_i)\\ =&\frac{n(n-1)}{2}+\frac{1}{2} \left(n^2-\sum_{i=1}^{k}siz_i^2 \right) \end{aligned}

式子的意义是原树路径数量+两点在不同子树中的点对数量。

观察式子发现只需要最小化 \sum_{i=1}^{k}siz_i^2,即子树大小平方和,此时我们可以将问题看作选择原树中的一条路径,最小化断开路径上的边后形成的连通块的大小的平方和。
对于这种寻找路径的问题,一种思路是像其它题解一样在路径端点的lca处拼接两条来自不同子树的路径;还有一种就是枚举节点,计算以该点为路径其中端点的答案,然后考虑用换根 dp 优化。

step 2. 换根 dp + 斜率优化

考虑怎样得到以某一个节点为路径端点时的答案,将其作为整棵树的根,设 f_u 为以 u 为根的子树中路径从 u 向下延伸时的答案,首先注意到路径一定会向下延伸直到叶子节点(siz 总和不变,份数越多平方和越小),易得转移方程:

f_u= \begin{cases} 1 & \text{u is leaf}\\ \min_{son} \{f_{son}+(siz_u-siz_{son})^2\} & \text{otherwise} \end{cases}

暴力枚举每个节点作为根即可得到 O(n^2) 做法、

接着考虑换根 dp 优化,设当前的根为 u,要将根换到 s,关键在于如何计算除去子树 s 后的 f_u。设 f_u 是从儿子 p 转移而来,观察转移方程:

\begin{aligned} f_u&=f_p+(siz_u-siz_p)^2\\ &=f_p+siz_u^2-2siz_usiz_p+siz_p^2 \end{aligned}

移项得到经典斜率优化式子:

2siz_usiz_p+f_u-siz_u^2=f_p+siz_p^2

每个儿子 son 看作点 (2siz_{son},f_{son}+siz_{son}^2),维护下凸壳,查找去掉子树 s 后的 f_u 的最优转移点就是查找斜率为 siz_u-siz_s 与下凸壳的切点,时间复杂度 O(n \log n)偷懒使用了不用动脑子的二分

代码,具体实现参考注释:

//CF1179D Fedor Runs for President
#include <bits/stdc++.h>
using namespace std;

inline int read() {
    int x = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x;
}

typedef long long ll;
const int MAXN = 5e5+5;

ll n, siz[MAXN];
ll f[MAXN], ans;

vector<int> T[MAXN];

struct convex {
    vector<int> stk;
    int top;

    inline ll X(int x) { return 2ll * siz[x]; }
    inline ll Y(int x) { return f[x] + 1ll * siz[x] * siz[x]; }
    inline long double slope(int x, int y) { return (long double) (Y(x) - Y(y)) / (X(x) - X(y)); }

    inline void push(int x) {
        if (top && X(stk[top]) == X(x) && Y(stk[top]) <= Y(x)) return;
        while (top && X(stk[top]) == X(x) && Y(stk[top]) > Y(x)) top--;//注意处理横坐标相同的情况
        while (top > 1 && slope(stk[top], stk[top - 1]) > slope(x, stk[top - 1])) top--;//查找最小值维护下凸壳
        stk[++top] = x;
    }

    inline int query(int k) {
        int L = 1, R = top - 1, ans = stk[top];
        while (L <= R) {
            int mid = (L + R) >> 1;
            if (slope(stk[mid], stk[mid + 1]) >= k) ans = stk[mid], R = mid - 1;
            else L = mid + 1;
        }
        return ans;
    }
} C;

void dfs1(int x, int fa) {
    siz[x] = 1;
    if (siz[x] == 1) return f[x] = 1, void();
    for (int son : T[x]) {
        if (son == fa) continue;
        dfs1(son, x);
        siz[x] += siz[son];
    }

    sort(T[x].begin(), T[x].end(), [](int x, int y) { return siz[x] < siz[y]; });//按siz排序,保证横坐标单调

    for (int son : T[x]) {
        if (son == fa) continue;
        f[x] = min(f[x], f[son] + (siz[x] - siz[son]) * (siz[x] - siz[son]));
    }
}

ll g[MAXN];

void dfs2(int x, int fa, ll ffa) {
    ll tf = 1e18;
    for (int son : T[x]) {
        if (son == fa) continue;
        tf = min(tf, f[son] + (n - siz[son]) * (n - siz[son]));//计算以x点为根时的答案
        g[son] = ffa + (siz[x] - siz[son]) * (siz[x] - siz[son]);
    }
    ans = min(ans, min(tf, ffa + siz[x] * siz[x]));

    if (x != 1 && T[x].size() == 1) return;

    C.stk.resize(T[x].size() + 1), C.top = 0;
    int sz = T[x].size();
    for (int k = 0; k < sz; k++) {
        int son = T[x][k];
        if (son == fa) continue;
        int p = C.query(n - siz[son]);
        if (p) g[son] = min(g[son], f[p] + (n - siz[son] - siz[p]) * (n - siz[son] - siz[p]));
        C.push(son);
    }

    C.stk.resize(T[x].size() + 1), C.top = 0;
    for (int k = sz - 1; ~k; k--) {
        int son = T[x][k];
        if (son == fa) continue;
        int p = C.query(n - siz[son]);
        if (p) g[son] = min(g[son], f[p] + (n - siz[son] - siz[p]) * (n - siz[son] - siz[p]));
        //对于每个son,计算前缀最优和后缀最优,两者再加上从父节点转移的答案一起取min
        C.push(son);
    }

    for (int son : T[x]) {
        if (son == fa) continue;
        dfs2(son, x, g[son]);
    }
}

int main() {
    n = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        T[u].push_back(v), T[v].push_back(u);
    }

    memset(f, 0x3f, sizeof(f));
    dfs1(1, 0); ans = f[1];
    dfs2(1, 0, 1e18);

    printf("%lld", n * (n - 1) / 2 + (n * n - ans) / 2);
    return 0;
}