【JEOI-R1】树的上色 题解

· · 个人记录

题意简述

树上有两个黑点,在每个单位时间内,每个黑点可以把自己相邻的一个白点变为黑色,求把整棵树所有点变为黑色的最短时间。

解题思路

看到这道题的第一反应是:P2018。(可以看看这道题)

先来一个简化版问题:

树上有一个黑点,在每个单位时间内,每个黑点可以把自己相邻的一个白点变为黑色,求把整棵树所有点变为黑色的最短时间。

f_i 表示把以 i 为根的子树全部染为黑色的最短时间。

对于叶子结点, f_i=1

对于非叶子结点,当仅有 1 颗子树时,显然 f_i=f_{k \in subtree(x)}+1 。当子树颗数不止 1 颗我们需要考虑给子节点染色的顺序。

引理

按照 f_{k \in subtree(i)} 从大到小顺序染色即可。

证明

该贪心思路可由“邻项交换”证明。

f_{k_1 \in subtree(i)}f_{k_2 \in subtree(i)} 满足 f_{k1}>f_{k2} 且有且仅有 k_1, k_2 两颗子树

若先对 k_1 染色,则 f_i=\max (f_{k_1},f_{k_2}+1)+1 (第一个 +1k_2 等待 k_1 根节点先染色的单位时间,第二个 +1 是父节点 f_i 染色的时间)。

假设交换顺序:先对 k_2 染色,则 f_i=\max (f_{k_1}+1,f_{k_2})+1 (含义同上)。

因此该证明变为比较:(前面的是交换前,后面的是交换后)

\max (f_{k_1},f_{k_2}+1)+1$ 和 $\max (f_{k_1}+1,f_{k_2})+1

同时减去 1 ,变为比较:

\max (f_{k_1},f_{k_2}+1)$ 和 $\max (f_{k_1}+1,f_{k_2})

因为 f_{k1}>f_{k2} ,可得 f_{k_1}+1>f_{k2},变为比较:

\max (f_{k_1},f_{k_2}+1)$ 和 $f_{k_1}+1

f_{k1}>f_{k2} 还可得:

f_{k1}+1>f_{k2}+1 \\ f_{k1}+1>f_{k1} \end{cases}

因此可得:

\max (f_{k_1},f_{k_2}+1)<f_{k_1}+1

\max (f_{k_1},f_{k_2}+1)+1<\max (f_{k_1}+1,f_{k_2})+1

所以可知,在交换前更优,即要先染 f_{k \in subtree(i)} 更大的点,该结论可以推广到子树颗数大于等于 2 的情况下。

证毕。

根据引理可知,对于非叶子结点,f_i=\max \{ f_{k \in subtree(i)}+w_k \}+1 ,其中 0 \leq w_k < size(subtree(i))size(subtree(i)) 表示 i 的子树颗数, w_k 表示选择子树 k 前已经染色了多少颗子树。

至此,我们完成了简化版问题的求解。

回到本题,本题题目中是先把两个节点染成了黑色,我们需要把问题转化为上面的简化版问题。

(这里可以先思考一下在往下看)

方法:找到两个黑点间的所有树边,枚举选择其中一条树边断开,分为两颗子树,两个子树的根节点分别为两个黑点,时间复杂度 O(n^2) ,而数据范围允许的时间复杂度为 O(nlogn)

所以我们枚举改为 O(logn) 的选择,自然可以想到二分。

对于二分,我们需要进行单调性证明。

分别设两个黑点的子树为 a,b

由于是一颗树形结构,所以两个黑点间的路径是唯一的。

对于断边操作,相当于分别给子树 a,b 增加了一颗链型的子树。

由于 ans=\max(f_a,f_b) ,这里的二分我们可以选择依赖的是 f_a 的单调性(即找到一个断边使两颗子树答案最大值最小)。

随着选择的边远离 a 点,f_a 呈非严格单调递增(更详细的说,当增加的节点个数小于一个定值,答案不变,超过后答案呈单调递增),f_b 非严格单调递减。由此找到了二分的两段性。

IL bool check(reg int p) {
    vis[s[p]] = vis[s[p] ^ 1] = 1;
    solve(a), solve(b);
    vis[s[p]] = vis[s[p] ^ 1] = 0;
    return f[a] >= f[b];
}

另一种理解

我们可以直接建立一个 0 号虚点,连接两个黑点,这颗树会变为一颗基环树。接下来我们可以进行基环树DP,由于实际染黑的只有环上 2 个结点,不能采取“广义根”方法来解决。所以我们采用和上方相同的方法,二分环上的断边即可。(代码其实是相同的,0 号节点不计入答案因此实际可以不建)

部分分

(部分分还是很足的,优化的暴力大概 16 分,发现树形DP模型即可得 64 分,加上二分就是正解)

子任务编号 测试点 n= 特殊性质 解法
0 1 ~ 2 16 暴力枚举/搜索即可
1 3 ~ 4 30 暴力枚举/搜索 + 剪枝
2 5 ~ 6 10^5 A 分析链的答案性质,推导 f_af_b ,然后把两点之间的链平均分配,使得 \mid f_a-f_b \mid 最小即可
3 7 ~ 12 10^4 上文所述的枚举选择其中一条树边断开
4 13 ~ 16 5\times 10^5 B 无需枚举,分别求解两颗子树即可
5 17 ~ 25 5\times 10^5 正解

特殊性质 A :保证输入数据构成一条链。

特殊性质 B :保证树上 fa_a=bfa_b=afa_x 指树上 x 的父节点)。

代码

std:

#include <bits/stdc++.h>
#define reg register
#define IL inline
#define N 500500
IL int read() {
    reg int x = 0;
    reg char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x;
}

int n, a, b;
int first[N], nxt[N + N], to[N + N], num = 1;

IL void add(reg int u, reg int v) { nxt[++num] = first[u], to[num] = v, first[u] = num; }
IL void adde(reg int u, reg int v) { add(u, v), add(v, u); }

int in[N], s[N], tot;

void dfs(reg int u, reg int fa = 0) {
    for (reg int i = first[u], v; i; i = nxt[i])
        if ((v = to[i]) ^ fa)              //不走父亲节点
            in[v] = i, dfs(v, u);
}

bool vis[N + N];
std::vector<int> g[N];
int f[N];

IL void cmax(reg int &x, reg int y) { x < y ? x = y : 0; }           //注意这里是引用

void solve(reg int u, reg int fa = 0) {           //对子树 u 的 f 值进行递归计算
    g[u].clear(), f[u] = 0;
    for (reg int i = first[u], v; i; i = nxt[i])
        if (!vis[i] && (v = to[i]) ^ fa)
            solve(v, u), g[u].push_back(-f[v]);
    std::sort(g[u].begin(), g[u].end());   //对子树的 f 值进行排序
    for (reg int i = 0; i < g[u].size(); ++i) cmax(f[u], i + 1 - g[u][i]);
}

IL bool check(reg int p) {
    vis[s[p]] = vis[s[p] ^ 1] = 1;     //标记边断开 
    solve(a), solve(b);             //分别计算两颗子树
    vis[s[p]] = vis[s[p] ^ 1] = 0;     //清除标记 
    return f[a] >= f[b];
}

int main() {

    n = read(), a = read(), b = read();
    for (reg int i = n; --i; adde(read(), read()))   //读入,建双边 
        ;
    dfs(a);    //从a点dfs建立有向图(树) 
    for (reg int x = b; x; s[++tot] = in[x], x = to[in[x] ^ 1])  //反向从b点走到a,记录他们之间的边
        ;
    reg int l = 1, r = tot, mid, v1 = 2e9, v2 = 2e9;
    while (l <= r) {                      //二分断开的边 
        mid = l + r >> 1;
        if (check(mid))
            l = mid + 1;
        else
            r = mid - 1;
    }
    if (l <= tot) {
        vis[s[l]] = vis[s[l] ^ 1] = 1;
        solve(a), solve(b), v1 = f[b];
        vis[s[l]] = vis[s[l] ^ 1] = 0;
    }
    if (l > 1) {
        --l;
        vis[s[l]] = vis[s[l] ^ 1] = 1;
        solve(a), solve(b), v2 = f[a];
        vis[s[l]] = vis[s[l] ^ 1] = 0;
    }
    return printf("%d", v1 < v2 ? v1 : v2), 0;  //输出最小值 
}