最近公共祖先(LCA)

· · 算法·理论

众所周知,你的祖先就是你的父母,你父母的父母,你父母的父母的父母,以此类推,树上节点的祖先就是它的父节点,它父节点的父节点,它父节点的父节点的父节点。

Part 1 定义

祖先(Ancestor):一个结点到根结点的路径上,除了它本身外的结点。

公共祖先(Common Ancestor):两个或多个结点相同的祖先

最近公共祖先(Lowest Common Ancestor):两个点的最近公共祖先就是这两个点的公共祖先里面,离根最远的那个。

Part 2 性质

Part 3 求解

Way 1 朴素算法

核心思想:可以每次找深度比较大的那个点,让它向上跳。显然在树上,这两个点最后一定会相遇,相遇的位置就是想要求的 LCA。 或者先向上调整深度较大的点,令他们深度相同,然后再共同向上跳转,最后也一定会相遇。

对于一棵树,我们先预处理每一个节点的深度,用 dt 数组来存储,其深度就是其父节点深度加一,则 dt[i]=dt[u]+1 ,因为我们要寻找两个节点的最近公共祖先,所以还要存储每个节点父节点,即 fa[i]=u

接下来,对于两个结点 uv ,我们不妨设 dt[u]>dt[v] ,且两点的深度差 dep=dt[u]-dt[v] 。然后我们便让 u 持续向上跳,即 u=fa[u] ,直到两个结点深度差为0。

接着,两个结点就到了同一深度,我们可以让两个结点同时开始向上跳,知道两点相遇,而相遇的点就是两点的 LCA

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+50;
vector<int>tree[MAXN];
queue<int>q;
int n,m,s,dt[MAXN],fa[MAXN];
void scan_tree(){
    int u,v;
    scanf("%d%d",&u,&v);
    tree[u].push_back(v);
    tree[v].push_back(u);
}
void dfs_depth(int now,int past){
    for(int i:tree[now]){
        if(i==past)continue;
        fa[i]=now;
        dt[i]=dt[now]+1;
        dfs_depth(i,now);
        //TODO
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<=n-1;i++)scan_tree();
    dfs_depth(s,-1);
    for(int i=1;i<=m;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        if(dt[u]<dt[v])swap(u,v);
        int depcha=dt[u]-dt[v];
        while(depcha--)u=fa[u];
        while(u!=v)u=fa[u],v=fa[v];
        printf("%d\n",u);
    }

    return 0;
}

由于需要遍历整棵树,所以时间复杂度是 O(n) ,对于P3379 【模板】最近公共祖先(LCA),会有超时的情况:提交记录

Way 2 倍增算法

核心思想:当我们在向上跳时,可能会存在跳了很长一段距离,时间很长的情况,所以我们不妨将跳跃的距离转化成2的指数的形式,定义 fa[i][j] 为从结点 i 向上跳2的 j 次方距离所到的结点,首先我们在搜索函数中会传两个参数,一个是当前节点 i ,另一个就是当前节点的父亲节点 now ,所以当前节点的直接父节点就有了,也就是 fa[i][0]=now ,因为我们函数是从根节点向下搜索的,所以我们一定在用求 fa[i][j] 节点时, f[f[i][j-1]][j-1] 肯定已经更新完成了,所以我们就可以用更新完的值去更新当前 fa 数组,我们可以将其视作 i 向上跳两次2的 j-1 次方,时间复杂度来到了 O(n log n)

fa[i][j]=fa[fa[i][j-1]][j-1]

对于这一颗树,我们在预处理时将它每一个结点向上跳2的指数倍的结点求解出来。

求解 LCA 时,由于我们存的是2的指数倍的距离,所以我们将此处的 dep 也进行二进制拆分优化,将 dep 转化成多个2的指数相加的形式,显然,当某一位为0时,在这一位上是无意义的,所以我们要判断 dep 是否为2的倍数,如果不是,则证明它的二进制形式最后一位为1,可以向上跳。因为我们每次都是在二进制形式的最后一位上运算,所以我们每次需要将 dep 除以2,来保证最后一位的更新。

因此我们得到如下的代码:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+50;
vector<int>tree[MAXN];
int n,m,s;
int fa[MAXN][20],dep[MAXN];
void dfs(int now,int past){
    for(int i:tree[now]){
        if(i==past)continue;
        dep[i]=dep[now]+1;
        fa[i][0]=now;
        int d_max=log2(dep[now]);
        for(int j=1;j<=d_max;j++){
            fa[i][j]=fa[fa[i][j-1]][j-1];
            //TODO
        }
        dfs(i,now);
        //TODO
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<=n-1;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        tree[u].push_back(v);
        tree[v].push_back(u);
        //TODO
    }
    dfs(s,-1);
    for(int i=1;i<=m;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        if(dep[u]<dep[v])swap(u,v);
        int dep_cha=dep[u]-dep[v];
        int j=0;
        while(dep_cha){
            if(dep_cha&1)u=fa[u][j];
            dep_cha>>=1;
            j++;
            //TODO
        }
        int d_max=log2(n);
        for(int j=0;j<=d_max;j++){
            if(fa[u][j]!=fa[v][j]){
                u=fa[u][j];
                v=fa[v][j];
            }
            //TODO
        }
        printf("%d\n",u);
        //TODO
    }
    return 0;
}

提交记录

由于我们是以2的指数倍向上跳,可能会存在跳过最近公共祖先,而跳到了其他公共祖先的情况,由定义可得如果一个结点是另外两个结点的公共祖先,那么它的祖先节点就是这两个结点的公共祖先,所以我们在求解时,应当从上向下寻找,并且我们最后找到的是两个结点最近公共祖先的下一个节点,这也就意味着我们需要输出 f[u][0]

所以我们修改后得出:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+50;
vector<int>tree[MAXN];
int n,m,s;
int fa[MAXN][20],dep[MAXN];
void dfs(int now,int past){
    for(int i:tree[now]){
        if(i==past)continue;
        dep[i]=dep[now]+1;
        fa[i][0]=now;
        int d_max=log2(dep[now]);
        for(int j=1;j<=d_max;j++){
            fa[i][j]=fa[fa[i][j-1]][j-1];
            //TODO
        }
        dfs(i,now);
        //TODO
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<=n-1;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        tree[u].push_back(v);
        tree[v].push_back(u);
        //TODO
    }
    dfs(s,-1);
    for(int i=1;i<=m;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        if(dep[u]<dep[v])swap(u,v);
        int dep_cha=dep[u]-dep[v];
        int j=0;
        while(dep_cha){
            if(dep_cha&1)u=fa[u][j];
            dep_cha>>=1;
            j++;
            //TODO
        }
        int d_max=log2(n);
        for(int j=d_max;j>=0;j--){
            if(fa[u][j]!=fa[v][j]){
                u=fa[u][j];
                v=fa[v][j];
            }
            //TODO
        }
        printf("%d\n",fa[u][0]);
        //TODO
    }
    return 0;
}

提交记录

因为两结点移至同一层次后发现它们是同一个节点,也就是说它们的公共祖先是它们其中一个节点,这时候就要直接输出了,所以我们应该特判一下

得到如下代码:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+50;
vector<int>tree[MAXN];
int n,m,s;
int fa[MAXN][20],dep[MAXN];
void dfs(int now,int past){
    for(int i:tree[now]){
        if(i==past)continue;
        dep[i]=dep[now]+1;
        fa[i][0]=now;
        int d_max=log2(dep[now]);
        for(int j=1;j<=d_max;j++){
            fa[i][j]=fa[fa[i][j-1]][j-1];
            //TODO
        }
        dfs(i,now);
        //TODO
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<=n-1;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        tree[u].push_back(v);
        tree[v].push_back(u);
        //TODO
    }
    dfs(s,-1);
    for(int i=1;i<=m;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        if(dep[u]<dep[v])swap(u,v);
        int dep_cha=dep[u]-dep[v];
        int j=0;
        while(dep_cha){
            if(dep_cha&1)u=fa[u][j];
            dep_cha>>=1;
            j++;
            //TODO
        }
        if(u==v){
            printf("%d\n",u);
            continue;
        }
        else{
            int d_max=log2(n);
            for(int j=d_max;j>=0;j--){
                if(fa[u][j]!=fa[v][j]){
                    u=fa[u][j];
                    v=fa[v][j];
                }
                //TODO
            }
            printf("%d\n",fa[u][0]);
        }
        //TODO
    }
    return 0;
}

提交记录

由于我们是向上跳跃2的指数倍, log2 函数是向下取整,会少跳一层,所以我们加上1使其完整跳跃

我们又双叒叕得到了下面的代码:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+50;
vector<int>tree[MAXN];
int n,m,s;
int fa[MAXN][20],dep[MAXN];
void dfs(int now,int past){
    for(int i:tree[now]){
        if(i==past)continue;
        dep[i]=dep[now]+1;
        fa[i][0]=now;
        int d_max=log2(dep[now])+1;
        for(int j=1;j<=d_max;j++){
            fa[i][j]=fa[fa[i][j-1]][j-1];
            //TODO
        }
        dfs(i,now);
        //TODO
    }
}
int main(){
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<=n-1;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        tree[u].push_back(v);
        tree[v].push_back(v);
        //TODO
    }
    dfs(s,-1);
    for(int i=1;i<=m;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        if(dep[u]<dep[v])swap(u,v);
        int dep_cha=dep[u]-dep[v];
        int j=0;
        while(dep_cha){
            if(dep_cha&1)u=fa[u][j];
            dep_cha>>=1;
            j++;
            //TODO
        }
        if(u!=v){
            printf("%d\n",u);
            continue;
        }
        else{
            int d_max=log2(n)+1;
            for(int j=d_max;j>=0;j--){
                if(fa[u][j]!=fa[v][j]){
                    u=fa[u][j];
                    v=fa[v][j];
                }
                //TODO
            }
            printf("%d\n",fa[u][0]);
        }
        //TODO
    }
    return ;
}

提交记录

OK,孩子们讲完了!

Part 4 习题

P3379 【模板】最近公共祖先(LCA)

P5002 专心OI - 找祖先

其他题我还没做,就先别做了吧