【笔寄】虚树 / P2495 题解

· · 个人记录

虚树是干嘛用的?

首先虚树是一颗被重建的树。

给你一颗树,n 个节点(1\leq n\leq10^7

此类问题就应该可以用虚树求解。

首先是一颗树:

用方框括起来的点就是被选中的点

预处理:我们先求出这颗树的 dfs 序,以及预处理一个较快(\log)求 lca 的数据结构。

我们建立虚树使用来模拟 dfs

一开始将 1 推入栈

--------
| 1
--------

然后以次类推,我们假设每次 dfs 先进入靠左的字树,一次加入 3,4,注意此时不连任何边

--------
| 1 3 4
--------

发现到底了,此时 dfs 会回溯到 3,进入 5,此时发现 5 已经不应当建到 4 的下面,尝试弹栈

注意弹栈的过程中弹一个节点就连一下此节点与新栈顶

假设此时栈顶为 x,次栈顶为 y,弹到 \operatorname{lca}(y,5)=y 时停下,其实就是没有弹(

分类讨论,此时 \operatorname{lca}(x(4),5)=\operatorname{lca}(y(3),5)=y(3),所以我们不需要额外建立点

直接连接 x(4)y(3),弹出 x,插入 5,继续!

--------
| 1 3 5
--------

然而发现又到了底部,回溯到 2,进入 7,按照上面的方案继续弹栈

弹栈的同时连接 53

此时 x=3,y=1

--------
| 1 3
--------

分类讨论,由于\operatorname{lca}(x(3),7)\not=\operatorname{lca}(y(1),7)=y(1),因此我们需要新建节点 \operatorname{lca}(x,7),也就是 2

连接 x\operatorname{lca}(x,7),弹出 x,插入 \operatorname{lca}(x,7),插入 7 就行了。

--------
| 1 2 7
--------

最后发现所有的节点都已经遍历的一遍,但是栈里面还有东西,所以把栈里面的节点全部相连。

连接 27,连接 12

至此,虚树建立完成

看起来并没有什么变化

分析一下可以发现每个节点至多只会添加一个 lca,所以节点个数至多为 2k+1 个(加上 1 号节点)

再分析一下发现时间复杂度是 O(n) 的,所以一开始预处理 dfn 序,然后将 k 个点按照 dfn 序排序后依次进行即可。

代码如下(树剖求 lca):

#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 250005;
// 定义图 
struct edge {
    int to, val, nxt;
} edges[MAXN << 1], nedges[MAXN << 2];
int head[MAXN], tot;
int nhead[MAXN << 1], ntot;
void add(int u, int v, int w) {
    edges[++tot].to = v; edges[tot].val = w; edges[tot].nxt = head[u]; head[u] = tot;
}
void nadd(int u, int v, int w) {
    nedges[++ntot].to = v; nedges[++ntot].val = w; nedges[tot].nxt = nhead[u]; nhead[u] = ntot;
}
// 树剖求 dfs 序和 lca 
int size[MAXN], fa[MAXN], maxson[MAXN], dep[MAXN];
void dfs1(int u, int f) {
    maxson[u] = -1;
    fa[u] = f;
    size[u] = 1;
    dep[u] = dep[f] + 1;
    for (int i = head[u]; i; i = edges[i].nxt) {
        if (edges[i].to == f) continue;
        dfs1(edges[i].to, u);
        size[u] += size[edges[i].to];
        if (maxson[u] == -1 || size[edges[i].to] > size[maxson[u]]) maxson[u] = edges[i].to;
    }
}
int top[MAXN], dfn[MAXN], times;
void dfs2(int u, int tp) {
    top[u] = tp;
    dfn[u] = ++times;
    if (~maxson[u]) dfs2(maxson[u], tp);
    for (int i = head[u]; i; i = edges[i].nxt) {
        if (edges[i].to == fa[u]) continue;
        if (edges[i].to == maxson[u]) continue;
        dfs2(edges[i].to, edges[i].to);
    }
}
int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    return v;
}
// 自定义比较函数
bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
} 
// 建虚树
int stack[MAXN << 1], tp;
void Build(int *List, int k) {
    tp = 1;
    stack[0] = 1;
    for (int i = 1; i <= k; i++) {
        if (lca(List[i], stack[tp - 1]) != stack[tp - 1]) {
            while (lca(List[i], stack[tp - 2]) != stack[tp - 2]) {
//              nadd(stack[tp - 2], stack[tp - 1]);
                printf("%d %d\n", stack[tp - 2], stack[tp - 1]);
                tp--;
            }
            if (lca(List[i], stack[tp - 1]) == stack[tp - 2]) {
//              nadd(stack[tp - 2], stack[tp - 1]);
                printf("%d %d\n", stack[tp - 2], stack[tp - 1]);
                tp--;
            }
            else {
                int p;
//              nadd(p = lca(List[i], stack[tp - 1]), stack[tp - 1]);
                printf("%d %d\n", p = lca(List[i], stack[tp - 1]), stack[tp - 1]);
                stack[tp - 1] = p;
            }
        }
        stack[tp++] = List[i];
    }
    for (int i = tp - 2; i >= 0; i--) {
//      nadd(stack[i], stack[i - 1]);
        printf("%d %d\n", stack[i], stack[i + 1]);
    }
}
// 结束
int lst[MAXN];
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d %d %d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 1; i <= 7; i++) printf("%d ", top[i]); puts("");
    for (int i = 1; i <= 7; i++) {
        for (int j = 1; j <= 7; j++) {
            printf("%d %d %d\n", i, j, lca(i, j));
        }
    }
    int k;
    scanf("%d", &k);
    for (int i = 1; i <= k; i++) scanf("%d", &lst[i]);
    sort(lst + 1, lst + 1 + k, cmp);
    Build(lst, k);
    return 0;
}
/*
7
1 2 1
2 6 1
6 7 1
2 3 1
3 5 1
3 4 1
4
7 3 5 4
*/