LCA算法
inoichi_lim
2020-03-22 12:11:42
$$\texttt{LCA算法By Billy2007}$$
1. 前言
众所周知Billy2007树论不行,为了记住求LCA的方法,于是开了这个天坑。
> LCA(Least Common Ancestors),即最近公共祖先,是指在有根树中,找出某两个结点u和v最近的公共祖先。
为了方便,我们这里定义:
```cpp
const int N=500001;
```
感谢@夏夜空 的题解!
2. 怎么求LCA
以[P3379](https://www.luogu.com.cn/problem/P3379)为例。
2.1 暴力
暴力算法:让这两个点像并查集一样暴力搜索,直到第一次相遇为止。
```cpp
struct node{
int bianh;
int fa,chd[N/1000];//考虑乐观情况
node(){
fa=0;
bianh=0;
memset(chd,0,sizeof(chd));
//这里3个都要定为0.
}
};
int lca(int a,int b){
if(a==b) return a;
return lca(nd[a].fa,nd[b].fa);
}
```
但是这个算法有问题。
显然我们不知道是 **$a$是$b$的父亲还是$b$是$a$的父亲**。
这个可以通过`dfs`解决,具体代码在后面。
并且空间只考虑**乐观**情况,可能到时候会出现$1$是根,而$2$到$N$都是兄弟的情况。
这样会造成空间的**大量浪费**。
并且就算可行,初始化`dfs`时间复杂度$O(n)$,每一次最坏时间复杂度$O(n)$,总时间复杂度$O(nm)$,肯定会被T飞(~~废话~~)。
----
2.2 倍增
我们可以使用倍增算法——一种通过~~抄题解~~记录$a$的前$2^i$代来达到$O(\log n)$求LCA的算法。
我们假设有一棵$n$点$m$边的**多叉**树。
一个一个点存显然不现实,那咋办?
- 可以把静态的数组转换成`vector`,更为方便。
- 从存每一个点的父亲变成存每一个点往上数的第$2^k(k\le 22)$个祖先。
转换后,一棵树的定义如下:
```cpp
vector<int> sons[N];//sons[i][j]表示i的第j个儿子。
int dep[N],lga[N],n,m,s,fa[N][22];//这里dep[i]表示第i个点的深度,lga[i]表示lb(i)+1。n,m,s同题目要求。
```
注:$lb(i)=log_2^i$。
这里$fa_{i,j}$表示$i$往上数的第$2^j$个祖先。因为$2^{22}=4194304$,远远大于$N$。
可能有人会问了:为什么这里$fa_{i,j}$存 **$i$往上数的第$2^j$个祖先**,而不存 **$i$往上数的第$j$个祖先**?
原因如下:
- $O(N\times N)$的空间复杂度显然$\color{darkblue}MLE$;
- 每次遇到一个点都要往上爬看他的祖先,时间复杂度巨大。如果这棵树是一条链,那么遍历这棵树的时间复杂度是$O(\sum_{i=0}^{n-1}i)=O(\frac{n(n-1)}{2})=O(n^2)$。(~~注:这里我也不知道为啥~~)
3. 倍增的实现
显然我们这里需要先做一个连接的函数。
```cpp
void add(int s,int t){
sons[s].push_back(t);
}
```
接下来,我们写一个求$lb(i)$的函数。
```cpp
void init(int to){
for(int i=1;i<=to;i++){
lga[i]=lga[i-1];
if(i==1<<lga[i-1]) lga[i]++;
}
}
```
接下来,还是那个老问题。
所以,我们需要一个`dfs`函数。
```cpp
void dfs(int now,int fanow){//参数fanow就是now他爸
dep[now]=dep[fanow]+1;//记录深度
fa[now][0]=fanow;
for(int i=1;(1<<i)<=dep[now];i++){
//记录祖先
fa[now][i]=fa[fa[now][i-1]][i-1];
}
for(int i=0;i<sons[now].size();i++){//注意,这里是vector
if(sons[now][i]!=fanow){//我们在加边是因为还没有dfs,所以把父亲也加入了sons
dfs(sons[now][i],now);
}
}
}
```
在记录祖先的过程中,这里$i$对应第$2^i$层,$now$往上第$2^i$祖先就是$now$往上第$2^{i-1}$个祖先的第$2^{i-1}$个祖先。
比如下图的树。
![](https://cdn.luogu.com.cn/upload/image_hosting/qnluy3uc.png?x-oss-process=image/resize,m_lfit,h_1770,w_2025)
(~~画的比较丑,不要喷~~)
显然$2$的第$2^0$个祖先是$1$,那么$5$的第$2^1$个祖先就是$5$的第$2^0$个祖先的第$2^0$个祖先。
哇哇哇,递归!
形成了递归关系后,那么接下来就好做多了。
4. 题外话
是不是所有 **实数$a$** 都可以用$\sum_{i=0}^n 2^i\times k_i(k_i\in \{0,1\})$表示呢?
答案是对。
显然任何一个**实数$a$** 都可以表示成二进制表示,那么二进制就是上式的形式。
5. 最核心的部分
假设$b(x)=2^x$。
```cpp
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);//这里为了方便,让x永远比y深
while(dep[x]>dep[y]) x=fa[x][lga[dep[x]-dep[y]]-1];
//这里的lga是lb(dep[x]-dep[y]+1),所以还要-1。
//这里就是直接让x跳到和y同一深度,再接下去搞。
//个人不明白为什么还要写个while。
//现在懂了,可能需要跳很多次(比如8->4->1)才可以到。
if(x==y) return x;
//如果一跳,得,直接一样了,那么lca就是x了(写y也可以/cy)
for(int i=lga[x];i>=0;i--){
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
/*
全场最NB代码
如果现在两个节点的b(x)层父亲相等,那么说明步子跨大了。
那么就不要跨了。
否则,就说明现在还需要跨。
那么就继续给我跨。
如果现在两个点已经到了合并态(注:比如上面那棵树的2和3,4和5),那么就不要管他了。
jojo!时间复杂度O(lb(n))!
*/
}
return fa[x][0];//这里x他爸就是lca,因为现在x和y已经到了合并态。
}
```
6. 完整代码
```cpp
#include<bits/stdc++.h>
#define ns "-1"
#define fs(i,x,y,z) for(ll i=x;i<=y;i+=z)
#define ft(i,x,y,z) for(ll i=x;i>=y;i+=z)
#define ll long long
#define ull unsigned long long
#define db double
#define ms(a,b) memset(a,b,sizeof(a))
#define sz(a) sizeof(a)
using namespace std;
const int N=500001,inf=0x7f7f7f7f;
vector<int> sons[N];
int dep[N],lga[N],n,m,s,fa[N][22];
void add(int s,int t){
sons[s].push_back(t);
}
void init(int to){
for(int i=1;i<=to;i++){
lga[i]=lga[i-1];
if(i==1<<lga[i-1]) lga[i]++;
}
}
void dfs(int now,int fanow){
dep[now]=dep[fanow]+1;
fa[now][0]=fanow;
for(int i=1;(1<<i)<=dep[now];i++){
fa[now][i]=fa[fa[now][i-1]][i-1];
}
for(int i=0;i<sons[now].size();i++){
if(sons[now][i]!=fanow){
dfs(sons[now][i],now);
}
}
}
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x][lga[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int i=lga[x];i>=0;i--){
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
inline int read(){
int date=0,w=1;char c=0;
while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();}
while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();}
return date*w;
}
int main(){
n=read();m=read(),s=read();
init(n);
for(int i=1;i<n;i++){
int x,y;x=read(),y=read();
add(x,y);
add(y,x);
}
dfs(s,0);
for(int i=1;i<=m;i++){
int x,y;x=read(),y=read();
printf("%d\n",lca(x,y));
}
return 0;
}
```