浅谈LCA

· · 题解

传送门 P3379 【模板】最近公共祖先

写在前面

LCA是图论中非常重要的算法之一。它的应用范围非常广泛,在NOIP中考察的非常多,代码容易实现,实用性很高,是 蒟蒻 入门必备算法。

作为一个刚学LCA蒟蒻 ,我决定写一篇博客来总结这个算法,希望在写博客的过程中更深入地了解和掌握这个算法。

算法内容

#### 在以下这个图中,$2$和$5$的$LCA$是$3$,而$1$和$5$的$LCA$为$4$。![](https://cdn.luogu.com.cn/upload/pic/65577.png) 看到这里,大家对$LCA$的概念应该了解得很清楚了。那么问题来了,这个算法应该如何实现呢? ------------ ## Ⅰ 暴力搜索 最容易想到的,也最容易实现的,是对于每个节点,用 $O(n)$的时间维护一个深度数组 $depth[]$ ,然后从每个询问的点遍历其祖先,两个询问点的公共祖先中 $depth[]$ 最小的即为答案。 每次查询的平均时间复杂度为$O(logn)$,即该树的深度,最坏时间复杂度为$O(n)$,对于本题可能会$TLE$。 ------------ ## Ⅱ 优化算法 #### 既然暴力会$TLE$,那么就让我们考虑一下优化算法。 ------------ ### $1^o$ 倍增求$LCA

倍增在信息学中是一种很重要的思想。在找寻询问点的祖先时,可以用到倍增的思想。DFS维护每个节点的i的向上 2^j个节点f[i][j],这样在向上寻找询问点的祖先时能得到一个logn级别的优化。由1,2,4,8,...,2^n等数能组成1-2^n所有数的性质可知,该优化是可行的。

如下图f[11][0]=10,f[11][1]=6,f[f[11][0]][0]=f[11][1]=6

2^n=2^{n-1}+2^{n-1}可得 \large\text{f[i][j]=f[f[i][j-1]][j-1]} 这样一个玄学的倍增式子。

具体代码实现

void dfs(int u,int father){
    Dep[u]=Dep[father]+1;//维护深度
    for(int i=0;i<=19;i++)
        f[u][i+1]=f[f[u][i]][i];//倍增递推式
    for(int e=first[u];e;e=nxt[e]){
        int v=go[e];
        if(v==father)
            continue;
        f[v][0]=u;
        dfs(v,u);//继续进行搜索
    }
}
int lca(int x,int y){
    if(Dep[x]<Dep[y])
        swap(x,y);
    for(int i=20;i>=0;i--){
        if(Dep[f[x][i]]>=Dep[y])
             x=f[x][i];
        if(x==y)
            return x;
    }//向上搜到同一深度
    for(int i=20;i>=0;i--){
        if(f[x][i]!=f[y][i]){
            x=f[x][i];
            y=f[y][i];
        }
    }//同时倍增向上寻找公共祖先
    return f[x][0];
}

完整代码请见文末

2^o ST表求LCA

这是一种和以上两种算法思路不同的算法,它和tarjan一样是基于DFS的,同时也是离线的,在这个方面上,倍增的应用范围更为广泛。让我们来看看这种算法是如何求解LCA的。

在一次DFS,也就是树的先序遍历+回溯中,维护每个节点的深度depth[]DFSdfn[],以及每个节点在DFS序中dfn[]出现最早的位置fst[]

1是图4dfn[]序列,图2是图4depth[]序列,图3是图41-nfst[]数组。

比如我们现在要求15LCA,就在dfn[]中找到fst[1]=2fst[5]=7两个节点,然后在[2,7]这个区间对应的depth[]最小的值,也就是4,即为答案。

为什么这个方法可行呢?因为DFS深度优先的性质保证了在搜完所有x的子节点之前不会访问到比xdepth[]更小的点,从而保证在dfn[]序列的这个区间段中不会出现比LCAdepth[]更小的点。

区间中深度最小的点可以用ST表来维护,O(nlogn)ST表,每次查询时间复杂度为O(1),对于m>n时,时间复杂度比LCA更为优秀。

具体代码实现

inline void dfs(int cur,int dep){
    fst[cur]=++cnt;
    dfn[cnt]=cur;
    depth[cnt]=dep+1;//维护三个数组
    for(int i=0;i<g[cur].size();i++){
        int t=g[cur][i];
        if(fst[t]==0){
            dfs(t,dep+1);//先序遍历
            dfn[++cnt]=cur;
            depth[cnt]=dep+1;//回溯
        }
    }
}
inline void ST_init(){
    for(int i=1;i<=cnt;i++){
        st[i][0]=i;//初始化ST表
    }
    int a,b;
    for(int j=1;j<=lg[cnt];j++){
        for(int i=1;i+(1<<j)-1<=cnt;i++){//DP求取ST表
            a=st[i][j-1];
            b=st[i+(1<<j-1)][j-1];
            if(depth[a]<depth[b])
                st[i][j]=a;
            else
                st[i][j]=b;
        }
    }

}
inline int lca(int x,int y){//询问
    x=fst[x];y=fst[y];
    if(x>y)
        swap(x,y);
    int k=lg[y-x];//取断点
    int a=st[x][k];
    int b=st[y-(1<<k)+1][k];
    if(depth[a]<depth[b])//查询区间最小值
        return dfn[a];
    else
        return dfn[b];
} 

完整代码请见文末

3^o tarjanLCA

相比于倍增而言,$tarjan$和$ST$表都是离线的做法,但$tarjan$时间复杂度为$O(n+m)$,比倍增和$ST$表都更加优秀。 $tarjan$做法如下:先读入所有的询问,然后在$DFS$时,维护一个$vis[]$数组,标记搜索了哪些询问点,如果关于$x$的询问$y$已经搜索过,则$LCA(x,y)=find(x)$,也就是$x$的祖先(这里写$y$也没问题)。$DFS$完以$rt$为根的子树后,用并查集将$rt$与它的父亲合并。 $tarjan$算法同样利用了$DFS$的性质,在搜索完以$rt$为根的子树前,所有$rt$的后代的祖先都是$rt$。这样就保证了$tarjan$算法的正确性。 #### 具体代码实现 ```cpp int find(int x){//并查集操作 if(fa[x]==x) return x; else return fa[x]=find(fa[x]); } void tarjan(int u){ vis[u]=1;//标记访问 for(auto qid:q[u]){//c++11,枚举所有与u有关的询问 if(query[qid].x==u) if(vis[query[qid].y]) query[qid].lca=find(query[qid].y); else if(vis[query[qid].x]) query[qid].lca=find(query[qid].x); } for(auto v:g[u]){ if(vis[v]) continue; tarjan(v);//DFS u的子节点 fa[v]=u;//合并并查集 } } ``` #### 完整代码请见文末 ------------ ### $4^o$ 树链剖分求$LCA

树剖的算法了解一下就行,看不懂也没有关系,其实我个人认为这题用树剖是大材小用。其实是我太菜不会写树剖

这里简单讲一下树剖的思路,如果两个询问点在同一条链上,其LCA即为深度较小的点,如果两个询问点不在同一条链上,就把深度较大的点跳到它所在的重链顶端节点的父亲节点。

给出伪代码

int lca(int x,int y){
    if x的深度大于y
        swap(x,y);
    if x与y在同一链上
        return x;
    else
        y=father[top[y]];
        return lca(x,y);
}

文章结尾

有一件很重要的事情跟大家说:洛谷的模板时间卡的太死,记得开O2优化。

TLE惨案

倍增求LCA

#include<bits/stdc++.h>
using namespace std;
const int maxn=500050;
int n,m,q,x,y,head[maxn<<1],tot,dep[maxn],f[maxn][21];
struct node{
    int to,nxt;
}e[maxn<<1];
inline void add(int u,int v){
    e[++tot].to=v;
    e[tot].nxt=head[u];
    head[u]=tot;
}
inline void dfs(int u,int pre){
    dep[u]=dep[pre]+1;
    f[u][0]=pre;
    for(int i=1;i<=20;i++)
        f[u][i]=f[f[u][i-1]][i-1];
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v!=pre){
            f[v][0]=u;
            dfs(v,u);
        }
    }
}
inline int lca(int x,int y){
    if(dep[x]<dep[y])
        swap(x,y);
    for(int i=20;i>=0;i--){
        if(dep[f[x][i]]>=dep[y])
             x=f[x][i];
        if(x==y)
            return x;
    }
    for(int i=20;i>=0;i--)
        if(f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x*f;
}
int main(){
    n=read();q=read();m=read();
    for(int i=1;i<n;i++){
        x=read();y=read();
        add(x,y);add(y,x);
    }
    dfs(m,0);
    while(q--){
        x=read();y=read();
        printf("%d\n",lca(x,y));
    }
    return 0;
}

ST表求LCA

#include<bits/stdc++.h>
using namespace std;
const int maxn=5e5+5;
const int maxbit=22;
vector<int> g[maxn];
int dfn[maxn*2];
int depth[maxn*2];
int lg[maxn*2];
int st[maxn*2][maxbit];
int fst[maxn];
inline int read(){
    char ch=getchar();
    int x=0,f=1;
    while((ch>'9'||ch<'0')&&ch!='-')
        ch=getchar();
    if(ch=='-')
        f=-1,ch=getchar();
    while('0'<=ch&&ch<='9')
        x=x*10+ch-'0',
        ch=getchar();
    return x*f;
} 
int cnt=0;
inline void dfs(int cur,int dep){
    fst[cur]=++cnt;
    dfn[cnt]=cur;
    depth[cnt]=dep+1;
    for(int i=0;i<g[cur].size();i++){
        int t=g[cur][i];
        if(fst[t]==0){
            dfs(t,dep+1);
            dfn[++cnt]=cur;
            depth[cnt]=dep+1;
        }
    }
}
inline void ST_init(){
    for(int i=1;i<=cnt;i++){
        st[i][0]=i;
    }
    int a,b;
    for(int j=1;j<=lg[cnt];j++){
        for(int i=1;i+(1<<j)-1<=cnt;i++){
            a=st[i][j-1];
            b=st[i+(1<<j-1)][j-1];
            if(depth[a]<depth[b])
                st[i][j]=a;
            else
                st[i][j]=b;
        }
    }
}
inline int lca(int x,int y){
    x=fst[x];y=fst[y];
    if(x>y)
        swap(x,y);
    int k=lg[y-x];
    int a=st[x][k];
    int b=st[y-(1<<k)+1][k];
    if(depth[a]<depth[b])
        return dfn[a];
    else
        return dfn[b];
} 
int main(){
    lg[0]=-1;
    for(int i=1;i<maxn*2;i++)
        lg[i]=lg[i>>1]+1;
    int n,m,s;
    n=read(),m=read(),s=read();
    //scanf("%d%d%d",&n,&m,&s);
    int x,y;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        //scanf("%d%d",&x,&y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs(s,0);
    ST_init();
    while(m--){
        x=read();y=read();
        cout<<lca(x,y)<<endl; 
    }
    return 0;
}
/*
5 5 4
3 1
2 4
5 1
1 4
2 4
3 2
3 5
1 2
4 5
*/

tarjanLCA(提交时记得选c++11)

#include<bits/stdc++.h>
using namespace std;
const int maxn=5e5+5;
struct qnode{
    int x,y,lca;
}query[maxn];
int fa[maxn];
bool vis[maxn];
vector<int> g[maxn],q[maxn];
void init(){
    for(int i=0;i<maxn;i++)
        fa[i]=i;
}
int find(int x){
    if(fa[x]==x)
        return x;
    else
        return fa[x]=find(fa[x]);
}
void tarjan(int u){
    vis[u]=1;
    for(auto qid:q[u]){
        if(query[qid].x==u){
            if(vis[query[qid].y])
                query[qid].lca=find(query[qid].y);
        }
        else if(vis[query[qid].x])
                query[qid].lca=find(query[qid].x); 
    }
    for(auto v:g[u]){
        if(vis[v])
            continue;
        tarjan(v);
        fa[v]=u;
    }
}
int main(){
    init();
    int n,m,s;
    cin>>n>>m>>s;
    int x,y;
    for(int i=1;i<n;i++){
        scanf("%d%d",&x,&y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    for(int i=1;i<=m;i++){
        scanf("%d%d",&query[i].x,&query[i].y);
        q[query[i].x].push_back(i);
        q[query[i].y].push_back(i);
    }
    tarjan(s);
    for(int i=1;i<=n;i++)
        printf("%d\n",query[i].lca);
}

本蒟蒻第一篇博客历时三天,到此写完。

updated on 2019/8/5: