LCA 最近公共祖先

· · 个人记录

一、简介

1、内容

LCA 算法(英文全称:Least Common Ancestors),即 最近公共祖先。具体的,给定一棵有根树,若节点 z 满足既是 x 的祖先,又是 y 的祖先,那么我们说 zxy 的公共祖先。在 xy 的所有公共祖先中,深度最深的一个称为 xy 的最近公共祖先。

2、用途

二、算法流程

朴素算法

我们不妨以下面这棵树为例:

举个例子:节点 75 的最近公共祖先是 2。我们可以通过一遍深度优先搜索来遍历每一个节点的深度。我们从深度较大的那一个节点出发,先让其向上查找祖先,并让找到的祖先与深度较浅的那一个节点深度相同,随后让两个节点同步向上查找,直到找到相同的节点。

算上查询次数,这个算法的时间复杂度为 O(nq)

倍增求 LCA

同样一棵树,我们继续深入研究。我们发现,上面的朴素算法之所以慢,是因为我们每一次只会一步一步慢慢向上查找。我们不妨将其与我们之前学过的二进制分解相结合,也就是利用倍增的思想,尝试让时间复杂度压缩至 \log 级别。

这类似于一个 RMQ。同样利用了动态规划的思想,将数据转移到一棵树上。令 v 为当前决策进行到的节点,uv 的父节点,i 表示我们的状态记录的是与 v 相距 2^i个节点的祖先。我们很容易可以得到:

随后我们根据朴素算法的思路,分成两个步骤解决问题:

  1. 同样先找到 x 的祖先中与 y 相同深度的那一个。区别在于,每次贪心地选择,利用二进制分解定理,每一次尽可能向根节点走更多的步数( 2^i 步 ),但是却不能比 y 的深度浅,这样我们可以更加快速地求出我们需要的 x 的祖先。
  2. 此时 xy 的深度相同,我们让 xy 继续按照上述倍增的方法继续查找,找到距离最远的且互不相同的 xy 的祖先节点。最后找到的 x(或 y)的父节点,即为我们要求的 LCA

参考代码

//倍增求 LCA
#include<bits/stdc++.h>
using namespace std;
int n, m, s, x, y, dep[500005], f[500005][25];
vector<int> G[500005];
inline int read(){
    int s = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        s = s * 10 + c - '0';
        c = getchar();
    }
    return s * f;
}
inline void write(int x){
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
inline void init(int u, int fa){
    dep[u] = dep[fa] + 1;
    for (int i = 1; i <= 20; i++)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (auto &v : G[u]){
        if (v == fa) continue;
        f[v][0] = u;
        init(v, u);
    }
}
inline int LCA(int x, int y){
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i--){
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
        if (x == y) return x;
    }
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i])
            x = f[x][i], y = f[y][i];
    return f[x][0];
}
int main(){
    n = read(), m = read(), s = read();
    for (int i = 1; i < n; i++){
        x = read(), y = read();
        G[x].push_back(y);
        G[y].push_back(x);
    }
    init(s, 0);
    while(m--){
        x = read(), y = read();
        write(LCA(x, y));
        putchar('\n');
    }
    return 0;
}

以上算法就是我们常用的 倍增求 LCA 算法,其总体时间复杂度为 O(q\log n)

但是实际上我们还可以用更加优化的算法。

Tarjan 算法

看到倍增算法,我们发现它的时间复杂度在 \log 级别。所以一旦 nq 超过了一个较大的数(例如 10^6),那么就很有可能 TLE

我们再次回顾朴素算法,我们可以把向上查找的过程理解为不断向上标记的过程。而我们可以考虑一个离线算法,把每一个问题存为节点上需要处理的信息。这样统一读入,统一输出的离线算法,时间复杂度控制在 O(n+q)

(当然只是理论复杂度,因为常数等等原因,实际上跑起来可能还不如倍增)

具体算法流程,我们还是通过分步骤解释。

  1. 根据深度优先搜索的搜索顺序,一个节点要进行回溯,那么它的所有子节点一定已经回溯完毕。我们可以对每一个节点进行分类标记。没有经过的节点,我们标记为 0;已经经过但没有回溯的节点,我们记为 1;已经经过且已经回溯的节点,我们记为 2
  2. 两个节点有公共祖先,那么这两个节点一定处于这个公共祖先的子树上。我们只要知道两个节点是否已经在同一子树上,就可以知道它们的最近公共祖先。
  3. 但是,“判断是否在同一子树”这个问题,很明显我们不能直接暴力。其实我们可以用并查集优化。我们在递归回溯时,把一个子节点合并到它的父节点所在的并查集。不难发现,合并时父节点的标记一定为 1,并且一定是这个集合的根节点。在一个节点遍历完所有子树后,我们调用关于这个节点的询问集。如果询问的两个节点都已经遍历过,那么它们一定存在公共祖先,即并查集的根节点,直接查找即可。注意并查集要压缩路径

参考代码

//Tarjan的LCA算法
#include<bits/stdc++.h>
using namespace std;
int n, m, s, k, x, y;
int fa[5000005], dep[5000005], v[5000005], lca[5000005], ans[5000005];
bool bk[5000005];
vector<int> G[5000005];
vector<int> Q[5000005], QI[5000005];
inline int read(){
    int s = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        s = s * 10 + c - '0';
        c = getchar();
    }
    return s * f;
}
inline void write(int x){
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
inline void add(int x, int y, int id){
    Q[x].push_back(y), QI[x].push_back(id);
    Q[y].push_back(x), QI[y].push_back(id);
}
inline int get(int x){
    if (x == fa[x]) return x;
    return fa[x] = get(fa[x]);
}
inline void tarjan(int x){
    v[x] = 1;
    for (auto &y : G[x]){
        if (v[y]) continue;
        dep[y] = dep[x] + 1;
        tarjan(y);
        fa[y] = x;
    }
    for (int i = 0; i < Q[x].size(); i++){
        y = Q[x][i];
        int id = QI[x][i];
        if (v[y] == 2){
            int t = get(y);
            if (ans[id] > dep[x] + dep[y] - 2 * dep[t]){
                lca[id] = t;
                ans[id] = dep[x] + dep[y] - 2 * dep[t];
            }
        }
    }
    v[x] = 2;
}
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) fa[i] = i;
    for (int i = 1; i < n; i++){
        x = read(), y = read();
        G[x].push_back(y);
        G[y].push_back(x);
    }
    for (int i = 1; i <= m; i++){
        x = read(), y = read();
        if (x == y) ans[i] = 0, lca[i] = x;
        else{
            add(x, y, i);
            ans[i] = 1 << 30;
        }
    }
    tarjan(1);
    for (int i = 1; i <= m; i++) write(lca[i]), putchar('\n');
    return 0;
}

观察上述代码,你会发现:Tarjan求LCA的算法更适合直接求两点间的距离

欧拉序 RMQ 的 LCA 算法

Tarjan 的 LCAji 算法虽然理论时间复杂度低,但是常数很大。下面我们一起来看一种更加巧妙的在线算法,它有相同的时间复杂度,并且常数还比 Tarjan 小。

我们之前的算法,都是直接在一棵树上进行查找。那么我们能不能考虑一种新的思想,把树形结构转化为线性结构并且按照 ST 表的思路进行查找呢?

我们引入一个新的数学方法--求一棵树的欧拉序。看下面这一棵树:

欧拉序在程序中,可以通俗地理解为:深度优先遍历一棵树时,依次走过的节点的编号,包括回溯时重新走回的父节点

例如上面这棵树,它的其中一个欧拉序为:1 2 5 2 6 2 1 3 1 4 7 9 7 4 8 10 8 4 1

我们不妨设节点 i 在这棵树的欧拉序中,第一次出现的位置为 pos_i。设欧拉序为a。那么,两个点 xy(事先保证 pos_x \leq pos_y) 的最近公共祖先 z,满足:

为什么呢?我有一个不太严谨的方法帮助思考。xy 的路径是一条链,而你可以把这条链看成一个函数的形状。然后你会发现,这条链的形状类似于一个二次函数或一次函数图像,函数的顶点坐标就是我们要求的 LCA

这个数值我们可以用 ST 表来维护,也就是将原本的 LCA 问题,转化为了 RMQ 查询最小值位置问题。至此,算法结束。

参考代码

#include<bits/stdc++.h>
using namespace std;
int n, m, tot, x, y, dep[200005], dfn[200005], pos[200005], st[200005][25], lg[200005], Ans;
vector<int> G[100005];
inline int read(){
    int s = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        s = s * 10 + c - '0';
        c = getchar();
    }
    return s * f;
}
inline void write(int x){
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
inline void dfs(int u, int fa, int d){
    dfn[++tot] = u;
    pos[u] = tot;
    dep[tot] = d;
    for (auto &v : G[u]){
        if (v == fa) continue;
        dfs(v, u, d + 1);
        dfn[++tot] = u;
        dep[tot] = d;
    }
} 
inline void RMQ() {
    for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i <= tot; i++) st[i][0] = i;
    for (int j = 1; j <= lg[tot]; j++){
        for (int i = 1; i + (1 << j) - 1 <= tot; i++){
            int k1 = st[i][j - 1], k2 = st[i + (1 << j - 1)][j - 1];
            st[i][j] = dep[k1] < dep[k2] ? k1 : k2;
        }
    }
}
inline int LCA(int l, int r){
    int s = lg[r - l + 1];
    int k1 = st[l][s], k2 = st[r - (1 << s) + 1][s];
    return dep[k1] < dep[k2] ? dfn[k1] : dfn[k2];
} 
int main(){
    n = read(), m = read();
    for (int i = 1; i < n; i++){
        x = read(), y = read();
        G[x].push_back(y);
        G[y].push_back(x);
    }
    dfs(1, 0, 1);
    RMQ();
    while(m--){
        x = read() ^ Ans, y = read() ^ Ans;
        Ans = LCA(min(pos[x], pos[y]), max(pos[x], pos[y]));
        write(Ans);
        putchar('\n');
    }
    return 0;
}

三、例题

【例一】点的距离

给定一棵 n 个点的树,Q 个询问,每次询问点 x 到点 y 两点之间的距离。

数据范围:1 \leq n \leq 10^51 \leq x,y \leq n

分析

我们直接求出 xy 两点的最近公共祖先,然后可以推出 xy 两点的距离:

参考代码

#include<bits/stdc++.h>
using namespace std;
int n, m, s, x, y, dep[100005], f[100005][25];
vector<int> G[100005];
inline int read(){
    int s = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        s = s * 10 + c - '0';
        c = getchar();
    }
    return s * f;
}
inline void write(int x){
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
inline void init(int u, int fa){
    dep[u] = dep[fa] + 1;
    for (int i = 1; i <= 20; i++)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (auto &v : G[u]){
        if (v == fa) continue;
        f[v][0] = u;
        init(v, u);
    }
}
inline int LCA(int x, int y){
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i--){
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
        if (x == y) return x;
    }
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i])
            x = f[x][i], y = f[y][i];
    return f[x][0];
}
int main(){
    n = read();
    for (int i = 1; i < n; i++){
        x = read(), y = read();
        G[x].push_back(y);
        G[y].push_back(x);
    }
    init(1, 0);
    m = read();
    while(m--){
        x = read(), y = read();
        write(dep[x] + dep[y] - dep[LCA(x, y)] * 2); 
        putchar('\n');
    }
    return 0;
}

【例二】天天爱跑步

小c 同学认为跑步非常有趣,于是决定制作一款叫做《天天爱跑步》的游戏。《天天爱跑步》是一个养成类游戏,需要玩家每天按时上线,完成打卡任务。

这个游戏的地图可以看作一一棵包含 n 个结点和 n-1 条边的树,每条边连接两个结点,且任意两个结点存在一条路径互相可达。树上结点编号为从 1n 的连续正整数。

现在有 m 个玩家,第 i 个玩家的起点为 s_i,终点为 t_i。每天打卡任务开始时,所有玩家在第 0 秒 同时从自己的起点出发,以每秒跑一条边的速度,不间断地沿着最短路径向着自己的终点跑去,跑到终点后该玩家就算完成了打卡任务。 (由于地图是一棵树,所以每个人的路径是唯一的)。小c 想知道游戏的活跃度,所以在每个结点上都放置了一个观察员。在结点 j 的观察员会选择在第 w_j 秒观察玩家,一个玩家能被这个观察员观察到当且仅当该玩家在第 w_j 秒也正好到达了结点 j 。小c 想知道每个观察员会观察到多少人?

注意:我们认为一个玩家到达自己的终点后该玩家就会结束游戏,他不能等待一 段时间后再被观察员观察到。 即对于把结点 j 作为终点的玩家:若他在第 w_j 秒前到达终点,则在结点 j 的观察员不能观察到该玩家;若他正好在第 w_j 秒到达终点,则在结点 j 的观察员可以观察到这个玩家。

分析

首先来分析一个观察员能够看到某一个玩家的条件。不妨设有一个观察员在点 x 上,那么,一个人要想经过 x,我们可以把它分成两种情况:

根据以上情况,我们可以把一个人的移动路径拆分成两条链,一条由 S_iLCA(S_i,T_i),另一条由 T_iLCA(S_i,T_i) 的一个子节点。

对于情况一,有

对于情况二,可以把它拆成两条路径之和,需要满足的条件是

对于任意的 x,我们只需要找到若干条经过以它为根节点的子树的路径,满足的上述条件的路径数量。不难发现我们可以用一个桶来维护上述权值的数量,并且桶内权值的数量是随着遍历的节点而不断变化的。设两个桶分别为 ab,我们在遍历到 S_i 时,需要 ++a[dep[S_i]],同理在遍历到 T_iLCA(S_i,T_i) 以及 fa[LCA(S_i,T_i)] 这些结点时,都需要对一些确定的权值进行修改。如果直接开二维数组,n^2 的空间直接炸掉。怎么办呢?

联想到 Tarjan 算法,我们将每一个询问当作结点的参数处理。本题亦可这么想。我们事先处理在遍历到某一个结点时,桶内大小发生变化的权值。这样的话,当我们搜索遍历到这个节点时,直接修改这些发生变化的权值即可。

如此想来,最终答案的表达式为:

考虑到 dep[i] 可能小于 w[i],我们在处理时直接将这些值统一加上一个 n 即可。

但是这样就结束了吗?

显然没有。

我们在处理完一条路径之后,是无法将桶内的元素清空的。因此,一条路径的贡献可能会影响到另一条不相干的路径。这时候它的贡献就会比实际值大。其实很容易就能排除这种情况。一个结点真正的答案总数,其实可以表示为--桶内目标权值在遍历前后大小的变化量。

所以我们开两个遍历储存一下遍历前的桶内权值的数量即可。

参考代码

#include<bits/stdc++.h>
using namespace std;
int n, m, x, y, w[300005];
int fa[300005], dep[300005], f[300005][25], ans[300005];
int a[600005], b[600005];
vector<int> G[300005], a1[300005], a2[300005], b1[300005], b2[300005];
inline int read(){
    int s = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if (c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        s = s * 10 + c - '0';
        c = getchar();
    }
    return s * f;
}
inline void write(int x){
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
inline void dfs(int u, int x){
    fa[u] = x;
    dep[u] = dep[x] + 1;
    for (int i = 1; i <= 20; i++)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (auto &v : G[u]){
        if (v == x) continue;
        f[v][0] = u;
        dfs(v, u);
    }
}
inline int LCA(int x, int y){
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 20; i >= 0; i--){
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
        if (x == y) return x;
    }
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i])
            x = f[x][i], y = f[y][i];
    return f[x][0];
}
inline void dfs1(int u, int fa){
    int f1 = a[dep[u] + w[u]], f2 = b[dep[u] - w[u] + n];
    for (auto &v : G[u]){
        if (v == fa) continue;
        dfs1(v, u);
    }
    for (int i = 0; i < a1[u].size(); i++) a[a1[u][i]]++;
    for (int i = 0; i < a2[u].size(); i++) a[a2[u][i]]--;
    for (int i = 0; i < b1[u].size(); i++) b[b1[u][i]]++;
    for (int i = 0; i < b2[u].size(); i++) b[b2[u][i]]--;
    ans[u] = (a[dep[u] + w[u]] - f1) + (b[dep[u] - w[u] + n] - f2);
}
signed main(){
    n = read(), m = read();
    for (int i = 1; i < n; i++){
        x = read(), y = read();
        G[x].push_back(y);
        G[y].push_back(x);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i++)
        w[i] = read();
    for (int i = 1; i <= m; i++){
        x = read(), y = read();
        int lca = LCA(x, y);
        a1[x].push_back(dep[x]);
        a2[fa[lca]].push_back(dep[x]);
        b1[y].push_back(dep[lca] * 2 - dep[x] + n);
        b2[lca].push_back(dep[lca] * 2 - dep[x] + n);
    }
    dfs1(1, 0);
    for (int i = 1; i <= n; i++) write(ans[i]), putchar(' ');
    return 0;
}

【例三】运输计划

L 国有 n 个星球,还有 n-1 条双向航道,每条航道建立在两个星球之间,这 n-1 条航道连通了 L 国的所有星球。

小 P 掌管一家物流公司, 该公司有很多个运输计划,每个运输计划形如:有一艘物流飞船需要从 u_i 号星球沿最快的宇航路径飞行到 v_i 号星球去。显然,飞船驶过一条航道是需要时间的,对于航道 j,任意飞船驶过它所花费的时间为 t_j,并且任意两艘飞船之间不会产生任何干扰。

为了鼓励科技创新, L 国国王同意小 P 的物流公司参与 L 国的航道建设,即允许小 P 把某一条航道改造成虫洞,飞船驶过虫洞不消耗时间。

在虫洞的建设完成前小 P 的物流公司就预接了 m 个运输计划。在虫洞建设完成后,这 m 个运输计划会同时开始,所有飞船一起出发。当这 m 个运输计划都完成时,小 P 的物流公司的阶段性工作就完成了。

如果小 P 可以自由选择将哪一条航道改造成虫洞, 试求出小 P 的物流公司完成阶段性工作所需要的最短时间是多少?

分析

很显然这是一个树形结构。

考虑每一个运输计划,两点之间的时间我们可以直接用 LCA 算法求出来。

假设我们要求的最终答案为 ans,那么,对于每一份时间大于 ans 的运输计划,我们肯定要找到一条公共边 w,使得所有运输计划中时间最长的一个(记作 maxd),减去 w 后的时间 t\leq ans

不难发现答案具有单调性,所以我们可以用二分找到最终的答案。

问题在于,我们如何快速找出所有不满足条件的运输计划共同经过的边呢?我们可以将差分算法引用到树上。定义参数 sum,在每个运输计划的起点和终点处 ++sum,在两点的 LCAsum-=2。则对于一个以 x 为根节点的子树,这棵树上 sum 的总和,就是 x 到其父节点的边被经过的次数。再找到经过次数与不满足条件的运输计划数相等的边(也就是所有不满足条件运输计划经过的公共边中)长度最大的一条,记作 w。则当 maxd-w\leq ans 时,ans 为一个符合条件的答案。

例二和例三都运用了“树上差分”算法,这种算法常常与 LCA一同使用。