[算法]tarjan 求强连通分量

· · 个人记录

啥是强连通分量

\text{tarjan} 算法目标

将图分成若干个强连通分量,并求出每个强连通分量里的点的信息。

\text{tarjan} 算法思想

这东西,不画图根本理解不了,所以贡献上一张有向图:

\text{tarjan}$ 算法的核心思想是 $\text{dfs}$,这里的 $\text{dfs}

方式默认为先访问当前节点再访问指向节点,并且,我们只会遍历以前没有遍历过的点,遍历过的点直接无视。

我们创建数组 \text{dfn}\text{low} 数组,含义如下:

我们现在可以手动求一下这两个值,我们定义有序数对 (\text{dfn, low}),每个节点都用这个有序数对表示,那么一路往下可得:

然后我们需要回溯,第一步回溯的就是 4 节点,由于 4 节点可以通过那条曲线连接到 2,且无法获取更小的 \text{dfn},所以 \text{low}_4 = 2,图如下:

此时再回溯到 3 节点,发现可以通过走到 4 再走到 2,或者说,4\text{low}3 的小,所以 4 可以更新 3\text{low},图如下:

此时回溯到 2,我们发现 2\text{low} 无法被更新了,所以又回溯到 1,发现 1 也是如此,算法结束。

所以 \text{low} 到底有什么用?其实我们还要维护一个栈,具体看如下图片:

现在我们发现,当一个点的 \text{low} 已经被完全更新完了后,如果它的 \text{dfn} == \text{low},那么就出栈直至弹出了这个点为之。那么求这个 \text{low} 有没有更形象的求解方法呢?当然有,这里下面分两种情况更新 \text{low}

if(!dfn[edges[i].to]){   //这里一定要是这样判断
    tarjan(edges[i].to);  //先 dfs 指向节点
    low[x] = min(low[x], low[edges[i].to]);   //在 low[x] 与 low[指向节点] 中取最小值
}
else if(vis[edges[i].to]){  //vis 记录是否出过栈,出过栈为 0,在栈里为 1
    low[x] = min(low[x], low[edges[i].to]);
}

然后,一个备受争议的地方出现了,就是第二种情况的

low[x] = min(low[x], LOW[edges[i].to])

我大写的地方是备受争议的地方,很多人纠结这里到底写 \text{dfn} 还是 \text{low},其实很容易证明,代码里的 low[edges[i].to] 永远只可能等于他本身,如果不等于,那么肯定在它前面有一个强连通分量,甚至是环套环都没问题,因为已经被出栈了,所以不可能走到这一步,所以写 low[edges[i].to] 或者 dfn[edges[i].to] 都是没问题的。

合起来判断的代码:

for(int i = head[x]; i; i = edges[i].next){  //这里我用的是链式前向星
    if(!dfn[edges[i].to]){
        tarjan(edges[i].to);
        low[x] = min(low[x], low[edges[i].to]);
    }
    else if(vis[edges[i].to]){
        low[x] = min(low[x], low/dfn[edges[i].to]);
    }
}

注意,low/dfn 指你哪种都可以写,都没问题。

然后就是简单的判断环节:

if(dfn[x] == low[x]){
    cnt++;  //这里维护强连通分量的个数
    int tmp;
    do{
        tmp = stk.top();
        stk.pop();
        color[tmp] = cnt;  //这里是维护一些点的信息,维护点在哪个强连通分量里
        vis[tmp] = false;   //出栈当然设为 false
    }while(tmp != x);   //这里要用 do-while,否则会多一个

}

\text{tarjan} 代码

缩点模板题代码,不懂请私信:


#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N = 1e5 + 5;

int n, m;
int a[N], u[N], v[N], in[N], newa[N], f[N];
int idx, dfn[N], low[N], color[N], sum[N], cnt;
bool vis[N];

stack<int> stk;

int head[N], tot;

struct Node{
    int to, next;
}edges[N];

void add(int u, int v){
    tot++;
    edges[tot].to = v;
    edges[tot].next = head[u];
    head[u] = tot;
}

void tarjan(int x){
    dfn[x] = low[x] = ++idx;   //这里时间戳从 1 开始
    stk.push(x);  //弹入栈
    vis[x] = true;
    for(int i = head[x]; i; i = edges[i].next){
        if(!dfn[edges[i].to]){
            tarjan(edges[i].to);
            low[x] = min(low[x], low[edges[i].to]);
        }
        else if(vis[edges[i].to]){
            low[x] = min(low[x], dfn[edges[i].to]);
        }
    }
    if(dfn[x] == low[x]){
        cnt++;
        int tmp;
        do{
            tmp = stk.top();
            stk.pop();
            color[tmp] = cnt;
            vis[tmp] = false;
        }while(tmp != x);
    }
}

void topo(){ //拓扑排序
    queue<int> q;
    for(int i = 1; i <= cnt; i++){
        if(!in[i]){
            f[i] = newa[i];
            q.push(i);
        }
    }
    while(!q.empty()){
        int x = q.front();
        q.pop();
        for(int i = head[x]; i; i = edges[i].next){
            in[edges[i].to]--;
            f[edges[i].to] = max(f[edges[i].to], f[x] + newa[edges[i].to]);
            if(!in[edges[i].to]){
                q.push(edges[i].to);
            }
        }
    }
}

void Solve(){
    cin >> n >> m;
    for(int i = 1; i <= n; i++){
        cin >> a[i];
    }
    for(int i = 1; i <= m; i++){
        cin >> u[i] >> v[i];
        add(u[i], v[i]);
    }
    for(int i = 1; i <= n; i++){  //这里因为图不连通
        if(!dfn[i]){
            tarjan(i);
        }
    }
    memset(head, 0, sizeof(head));
    for(int i = 1; i <= m; i++){   //重构图
        if(color[u[i]] != color[v[i]]){
            in[color[v[i]]]++;
            add(color[u[i]], color[v[i]]);
        }
    }
    for(int i = 1; i <= cnt; i++){
        for(int j = 1; j <= n; j++){
            if(color[j] == i){
                newa[i] += a[j];
            }
        }
    }
    topo();   //这里不属于 tarjan,是题目要求
    int maxi = -1e9;
    for(int i = 1; i <= cnt; i++){
        maxi = max(maxi, f[i]);
    }
    cout << maxi;
}

signed main(){
  Solve();
  return 0;
}