最近公共祖先,LCA

· · 个人记录

LCA(Least Common Ancestors),即最近公共祖先,

是指在有根树中,找出某两个结点u和v最近的公共祖先。

———来自百度百科

P3379 【模板】最近公共祖先(LCA)

题目描述

当然,这题肯定不是用n^2暴力,那它用了什么优化呢?

法一:倍增(提交记录 3.18 s)

倍增算法时间复杂度为O(nlogn)

这个方法是太久以前写的,码风奇特解释奇怪,不喜欢的别不看了啊

所谓倍增,就是按2的倍数来增大,也就是跳 1,2,4,8,16,32……

从大向小跳,即按32,16,8,4,2,1……来跳,如果大的跳不过去,再把它调小

(可以拿二进制为例,11(1011),从高位向低位填很简单,如果填了这位之后比原数大了,那我就不填)

所以算法的时间复杂度为O(nlogn)

首先我们要记录各个点的深度和他们2^i级的的祖先,

用数组depth表示每个节点的深度,fa[i][j]表示节点i的2^j级祖先

#include<bits/stdc++.h>
using namespace std;
const int M = 500005;
vector <int> E[M];
void add(int x, int y) {
    E[x].push_back(y);
}
int depth[M],fa[M][25],lg[M];
void dfs(int now, int fath) {
    fa[now][0]=fath;
    depth[now]=depth[fath] + 1;
    for(int i=1;(1<<i)<=depth[now];i++)
        fa[now][i] = fa[fa[now][i-1]][i-1];
//意思是now的2^i祖先等于now的2^(i-1)祖先的2^(i-1)祖先
    for(int i=0;i<E[now].size();i++)
        if(E[now][i]!=fath) dfs(E[now][i],now);
}
int LCA(int x, int y) {
    if(depth[x]<depth[y]) swap(x, y);
//令x的深度 >= y的深度 
    while(depth[x]>depth[y])
        x=fa[x][lg[depth[x]-depth[y]]-1];
//先跳到同一深度
    if(x==y) return x;
//如果x是y的祖先,那他们的LCA肯定就是x
    for(int k=lg[depth[x]]-1; k>=0;k--)
        if(fa[x][k]!=fa[y][k])
//因为我们要跳到它们LCA的下面一层(为什么?因为直接跳到它们的LCA,因为这可能会误判,可能一下子跳太过头了)
            x=fa[x][k],y=fa[y][k];
    return fa[x][0];
//返回父节点
}
int main() {
    int n,m,s;
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1;i<n;i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    for(int i=1;i<=n;i++)//预先算出log_2(i)+1的值,使2^lg[i]>i,用的时候直接调用就可以了
        lg[i]=lg[i-1]+(1<<lg[i-1]==i);//当2^(lg[i-1])==i,为使2^lg[i]>i,lg[i]=lg[i-1]+1 
    dfs(s,0);
    for(int i=1;i<=m;i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        printf("%d\n",LCA(x,y));
    }
    return 0;
}

法二:ST算法+RMQ(提交记录 3.94 s)

ST算法时间复杂度是O(n+q+nlogn)

ST算法

ST(Sparse Table)算法是一个非常有名的在线处理RMQ问题的算法,它可以在O(nlogn)时间内进行预处理,然后在O(1)时间内回答每个查询。

P3865 【模板】ST 表

先确定上面那个算法你会了再来学这个方法!!!

反正它最慢,不学也没事

ST算法解LCA

先试试把访问顺序和深度写出(比如样例):

访问顺序:42135 4((2),(1((3),(5)))

深度:   12233 1((2),(2((3),(3)))

然后如何看出它是什么样的一棵树呢

有一种选择是在儿子节点中间插入父亲,就像这样:

访问顺序:424131514 4((2),4,(1((3),1,(5),1)),4)

深度:    121232321 1((2),1,(2((3),2,(3),2)),1)

我们用

可以用一遍dfs算出所有值:

void dfs(int u,int depth) {
    dep[++cnt]=depth;
    p[cnt]=u;
    first[u]=cnt;
    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i];
        if(first[v]) continue;
        dfs(v,depth+1);
        p[++cnt]=u,dep[cnt]=depth;
    }
}
访问顺序 p[]:424131514 4((2),4,(1((3),1,(5),1)),4)

深度   dep[]:121232321 1((2),1,(2((3),2,(3),2)),1)

LCA就是在这个深度序列中取 first[x]~first[y],找到深度最小值并输出这个最小值的编号

在查询 first[x]~first[y] 中我们可以记录最小的深度dep[i]和它的编号i, 那么LCA(x,y)就是p[i]

(不懂这两句没关系,我下面会慢慢解释的)

现在我们来看看样例:

    p[ ] = { 4,2,4,1,3,1,5,1,4 }

  dep[ ] = { 1,2,1,2,3,2,3,2,1 }

first[ ] = { 4,2,5,1,7 }

比如算一算LCA(2,4)

LCA(2,4) -> min( dep[ first[2]~first[4] ] ) 中最小的深度的 p[i] (转化为编号)

first[2]=2, first[4]=1

dep[1~2]中最小的是dep[1]=1,那么p[1]=4就是答案

是不是有点感觉了?

再比如LCA(3,2) -> min( dep[ first[3]~first[2] ] ) 中最小的深度的 p[i]

first[3]=5, first[2]=2

dep[2~5]中最小的是dep[3]=1,那么p[3]=4就是答案

是不是感觉有点懂了?

再来一个LCA(3,5) -> min( dep[ first[3]~first[5] ] ) 中最小的深度的 p[i]

first[3]=5, first[5]=7

dep[5~7]中最小的是dep[6]=2,那么p[6]=1就是答案

好,你会了 然而并没有

完了吗?没有

为什么?因为我们写了半天还是个n^2暴力。。。 所有没看懂的别往下看了,先把暴力学会吧你!

但我们还能优化 没优化那我学个锤子, 我们已经成功把LCA转化成了区间最小值(只不过在取最小值时要把编号也记录下来),那么我们想到了什么?线段树!!! st表

在查询最小值时用st表优化就可以nlogn预处理,O(1)查询

完整代码

#include<bits/stdc++.h>
using namespace std;
const int MAXN=1000005;
vector<int> e[MAXN];
int n,m,root,cnt=0;
int p[MAXN],dep[MAXN],first[MAXN];
//p[i]表示dfs第i个访问的结点
//dep[i]表示p[i]的深度
//first[i]表示第一次访问p[i]出现的下标
void dfs(int u,int depth) {
    dep[++cnt]=depth;
    p[cnt]=u;
    first[u]=cnt;
    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i];
        if(first[v]) continue;
        dfs(v,depth+1);
        p[++cnt]=u,dep[cnt]=depth;
    }
}
int lg[MAXN],f[MAXN][20],id[MAXN][20];
void ST() {//求1~cnt内dep的最小值
    for(int i=1; i<=cnt; i++)
        lg[i]=lg[i>>1]+1;
    for(int i=1; i<=cnt; i++)
        f[i][0]=dep[i],id[i][0]=p[i];
    for(int j=1; j<=lg[cnt]; j++)
        for(int i=1; i+(1<<j)-1<=cnt; i++) {
            if(f[i][j-1]<=f[i+(1<<(j-1))][j-1])
                f[i][j]=f[i][j-1],id[i][j]=id[i][j-1];
            else
                f[i][j]=f[i+(1<<(j-1))][j-1],id[i][j]=id[i+(1<<(j-1))][j-1];
        }
}
inline int query(int l,int r) {
    int k=lg[r-l+1]-1;
    if(f[l][k]<f[r-(1<<k)+1][k]) return id[l][k];
    return id[r-(1<<k)+1][k];
}
int LCA(int x,int y) {
    x=first[x];
    y=first[y];
    if(x>y) swap(x,y);
    return query(x,y);
}
int main() {
    scanf("%d%d%d",&n,&m,&root);
    for(int i=1; i<n; i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs(root,1);
    ST();
    while(m--) {
        int x,y;
        scanf("%d%d",&x,&y);
        printf("%d\n",LCA(x,y));
    }
    return 0;
}

法三:Tarjan离线法(提交记录 2.60 s)

Tarjan离线法详细解释

Tarjan(离线)算法时间复杂度是O(n+q)。

离线算法

离线算法( off line algorithms),是指基于在执行算法前输入数据已知的基本假设,也就是说,对于一个离线算法,在开始时就需要知道问题的所有输入数据,而且在解决一个问题后就要立即输出结果。

什么玩意?

我只知道 通俗来说,离线算法就是先把所有问题存下来,然后在求解时把答案填到一个数组(不一定按顺序),然后一起把答案输出

那么如果这题考虑用离线算法,就要把所有LCA算一遍

如果把树画出来,我们先从下往上看

会发现树其实是一个最终合到只剩根的并查集

那遍历时先往下搜再直接合并呗

E[u][i] 指u连的边( u->E[u][i] )

    for(int i=0; i<E[u].size(); i++) {
        int v=E[u][i];
        if(vis[v]) continue;
        Tarjan(v);
        fa[v]=u;
    }

那怎么填答案呢?

答案是什么?其实LCA(u,v)就是搜索到u时的find(v)

e[u][i].first 指一个问题是求LCA(u,e[u][i].first)

e[u][i].second 指LCA(u,e[u][i].first)是第几个问题

    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i].first;
        if(!vis[v]) continue;
        ans[e[u][i].second]=find(v);
    }

完整代码

#include<bits/stdc++.h>
using namespace std;
const int M = 500005;
vector <int> E[M];
vector <pair<int,int> > e[M];
int fa[M],ans[M];
bool vis[M];
int find(int x) {
    if(fa[x]==x) return x;
    return fa[x]=find(fa[x]);
}
void Tarjan(int u) {
    vis[u]=1;
    for(int i=0; i<E[u].size(); i++) {
        int v=E[u][i];
        if(vis[v]) continue;
        Tarjan(v);
        fa[v]=u;
    }
    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i].first;
        if(!vis[v]) continue;
        ans[e[u][i].second]=find(v);
    }
}
int main() {
    int n,m,s;
    scanf("%d%d%d",&n,&m,&s);
    for(int i=1; i<=n; i++)
        fa[i]=i;
    for(int i=1; i<n; i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        E[x].push_back(y);
        E[y].push_back(x);
    }
    for(int i=1; i<=m; i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        e[x].push_back({y,i});
        e[y].push_back({x,i});
    }
    Tarjan(s);
    for(int i=1; i<=m; i++)
        printf("%d\n",ans[i]);
    return 0;
}

这种思想还挺有趣的,可惜题目不多,这里只有一道 P5838 [USACO19DEC]Milk Visits G 双倍经验:SP11985 GOT - Gao on a tree。(从 RedreamMer 大佬博客捞的)

可能离线都被莫队做掉了。

法四:树链剖分(提交记录 2.46 s)

树链剖分算法时间复杂度是O(2n+mlogn) 但它常数小,所以跑的最快

别看板子题是道蓝题我用它来做黄题很搞笑

其实它在这里只是两边dfs,而板子题还得路径上加加改改,用线段树优化,把心态敲炸

树链剖分核心思想是把树转化成一条条链

一些概念(大概了解即可)

  1. 重儿子:对于每一个非叶子节点,它的儿子中以那个儿子为根的子树节点数最大的儿子为该节点的重儿子
  2. 轻儿子:对于每一个非叶子节点,它的儿子中非重儿子的剩下所有儿子即为轻儿子
  3. 叶子节点:没有重儿子也没有轻儿子(因为它没有儿子)
  4. 重边:一个父亲连接他的重儿子的边称为重边
  5. 轻边:剩下的即为轻边
  6. 重链:相邻重边连起来的连接一条重儿子的链叫重链

dfs1()

这个dfs要处理这几件事情:

int dfs1(int u,int f,int depth) {
    dep[u]=depth;
    fa[u]=f;
    siz[u]=1;
    int maxson=-1;
    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i];
        if(v==f) continue;
        siz[u]+=dfs1(v,u,depth+1);
        if(siz[v]>maxson) maxson=siz[v],son[u]=v;
    }
    return siz[u];
}

dfs2()

这个dfs要处理这几件事情:

dfs2顺序是先重再轻,这样每一条重链的新编号就是连续的

又因为是dfs,所以每一个子树的新编号也是连续的

void dfs2(int u,int topf) {//topf 当前链的最顶端的节点
    dfn[u]=++cnt;
    top[u]=topf;
    if(!son[u]) return ;//如果没有儿子则返回 
    dfs2(son[u],topf);//按先处理重儿子,再处理轻儿子的顺序递归处理
    for(int i=0; i<e[u].size();i++)
        if(!dfn[e[u][i]])
            dfs2(e[u][i],e[u][i]);//对于每一个轻儿子都有一条从它自己开始的链
}

考虑到可以用top数组快速往上跳,每次看看top是不是一样,不一样就继续

(top[ ]:每个点的所在链顶点)

(相当于把top当成fa暴力往上跳,最后跳到一条链上)

那么再看看谁的深度浅就输出谁

(一条链上的LCA就看谁的深度浅就输出谁)

别看瞎扯了这么多,核心代码超短

int LCA(int x,int y) {
    while(top[x]!=top[y]) {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }   
    return dep[x]<dep[y]?x:y;
}

完整代码

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2000005;
vector<int> e[MAXN];
int n,m,root,cnt=0;
int dep[MAXN],fa[MAXN],son[MAXN],siz[MAXN],top[MAXN],dfn[MAXN];
int dfs1(int u,int f,int depth) {
    dep[u]=depth;
    fa[u]=f;
    siz[u]=1;
    int maxson=-1;
    for(int i=0; i<e[u].size(); i++) {
        int v=e[u][i];
        if(v==f) continue;
        siz[u]+=dfs1(v,u,depth+1);
        if(siz[v]>maxson) maxson=siz[v],son[u]=v;
    }
    return siz[u];
}
void dfs2(int u,int topf) {
    dfn[u]=++cnt;
    top[u]=topf;
    if(!son[u]) return ;
    dfs2(son[u],topf);
    for(int i=0; i<e[u].size(); i++)
        if(!dfn[e[u][i]])
            dfs2(e[u][i],e[u][i]);
}
int LCA(int x,int y) {
    while(top[x]!=top[y]) {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }   
    return dep[x]<dep[y]?x:y;
}
int main() {
    scanf("%d%d%d",&n,&m,&root);
    for(int i=1; i<n; i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs1(root,0,1);
    dfs2(root,root);
    while(m--) {
        int x,y;
        scanf("%d%d",&x,&y);
        printf("%d\n",LCA(x,y));
    }
    return 0;
}