最小树形图

· · Algo. & Theory

定义

有向树

对于无向图中的树来说,我们不用区分一个边是从深度小的点指向深度大的点,甚至不用考虑那个是根节点。但是如果树中的边都是有向边呢?很明显在两个点之间建两条边是很粗暴的,而且这样就相当于无向图的最小生成树了。

我们规定,两个点之间的边一定是从深度小的点指向深度大的点(即外向树),这样的树一定是有根的,根就是入度为 0 的节点。下面给出了一棵外向树:

根据定义,我们可以知道外向树中除了根节点的入度为 0 之外的所有结点的入度都为 1;如果一个有向图满足以上性质并不存在有向环,那它也同时肯定是一个外向树,正确性显然。

树形图

对于一个有向图来说,给定一个根节点 r,以这个节点为根节点且包含了这个有向图中所有节点的外向树叫做有向图的树形图(以 r 为根)。下面给出了一个有向图的树形图(以 a 为根):

其中标红的边为树形图上的边。

最小树形图

和最小生成树一样,我们定义最小树形图为所有树形图中边权和最小的那一个。

下图给出了一个图的最小树形图(以 a 为根):

其中红色的边为最小树形图上的边(这里最小树形图的边权和为 6)。

最小树形图的求解

以下假设我们求的是以某个节点为根的最小树形图。

求解最小树形图主要有两种方法,O(n(n+m)) 的朱刘算法和 O(m+n\lg m) 的来自 Tarjan 的 DMST 算法(使用斐波那契堆可以优化到 O(m+n\lg n) 的,但是一般没必要)。

朱刘算法(也称 Edmonds 算法)

有向无环图上的最小树形图

这个时候我们对于除了根节点之外的每个节点都选择一条边权最小的入边加入一个边集合中,如果每个非根节点都有这样一条入边,并且由于有向无环图没有有向环,那么就满足了外向树的性质。所以这个边集合就是这个有向无环图的最小树形图。

有向环的最小树形图

如果有根,那么把连接到根节点的那条边断开就是最小树形图;如果无根,我们只需要把权值最大的一条边断开即可。

有向图上的最小树形图

我们先使用处理有向无环图的最小生成树的算法来对有向图进行处理。如果有非根节点没有入边,那么肯定无解了。

如果边的集合中没有环,那这就是最小树形图集合了;如果有环,那么我们需要在这个环上断开一条边,再从其它节点连一条边过来。

这里需要注意:因为在当前的边的集合中每个节点至多有 1 条入边,所以每个环上的节点一定没有从环外指来的边。换句话说就是每个环都是独立的。

那么从环外连的这条边就会顶替掉它指向的环内的节点原来的入边。于是这条边产生的新的贡献就是它的边权减去那个环内节点的入边的边权。

因为对于每个指向环的边除了权值以外与环内节点就没有什么关系了,我们就可以把每个环缩成点,同时把不在环上的单个点也视作环,再在这个新图上用上面说的方法连上有向边就行了。

这就是朱刘算法。因为每次至少会把 2 个节点缩成 1 个点,所以至少会减少一条边。于是算法最多会跑 n-1 次。并且每次循环的复杂度是 O(n+m) 的,算法的时间复杂度就是 O(n(n+m)) 的了(事实上随机数据下跑得很快)。

我们再梳理一遍朱刘算法的过程:

  1. 找到每个点的最小入边
  2. 检测是否有环
    • 有环:把环缩成点,改变边权,跳转到第 1 步。
    • 无环:算法结束,返回答案。

实现

思路是有了,具体的代码怎么实现呢?这里着重讲一下找环的代码。

如果你不嫌麻烦的话,可以写 tarjan 找强连通分量。但是这个图中每个环都是互相独立的,我们不需要那么麻烦。

对于一个节点 u,假设我们现在需要找从它往自己所连的入边那里跳能找的的环,很明显如果存在就只存在这一个。那么我们用一个临时变量 v=u,然后让 v 一直往上跳,把 v 经过的所有节点就标记上 u,这里用一个 low 数组来存,标记就让 low[v]=u。当然如果 v 跳到根节点就说明不可能存在环了。

v 跳到一个已经有标记的节点,我们需要分情况讨论:

代码实现如下(注:col[i] 为节点 i 所在的环的编号,cnt 为当前环的数量,rt 为树形图的根节点,iu[i]i 节点所连入边的源点):

for (int u = 1; u <= n; u++) {
    if (col[u]) continue; // 如果当前节点已经被标记过了,当然不用继续了
    int v = u;
    while (v != rt && !low[v]) low[v] = u, v = iu[v]; // 往上跳
    if (v != rt && low[v] == u) { // 找到一个新的环
        // 给环上每个节点标记所属的环
        col[v] = ++cnt;
        for (int i = iu[v]; i != v; i = iu[i]) col[i] = cnt;
    }
}
// 给不在环上的节点标上编号
for (int i = 1; i <= n; i++) if (!col[i]) col[i] = ++cnt;

因为每个点的 low 至多被更新一次,所以这段代码的时间复杂度是 O(n) 的。

#include <iostream>

using namespace std;

char buf[1 << 21], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
int read() {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    return x * f;
}

char wbuf[(1 << 21) + 1], *p3 = wbuf;

#define flush (fwrite(wbuf, 1, p3 - wbuf, stdout), p3 = wbuf)
#define putchar(__x__) (p3 == wbuf + (1 << 21) ? flush : p3, (*p3++) = (__x__))
#define endl putchar('\n')
#define space putchar(' ')

void write(int x) {
    static int stk[100], top;
    if (!x) return void(putchar('0'));
    if (x < 0) putchar('-'), x = -x;
    top = 0;
    while (x) stk[++top] = x % 10, x /= 10;
    while (top) putchar(stk[top--] + '0');
}

void write(const char* str) {
    for (int i = 0; str[i]; i++) putchar(str[i]);
}

#include <cstring>

constexpr int N = 110, M = 10010, inf = 0x7fffffff;
int n, m, s;
int h[N], eu[M], ev[M], ew[M], ne[M], idx;
int iu[N], iw[N], col[N], cnt, low[N];

void addedge(int u, int v, int w) {
    ++idx;
    eu[idx] = u, ev[idx] = v, ew[idx] = w;
    ne[idx] = h[u];
    h[u] = idx;
}

int edmonds(int rt) {
    int res = 0;
    while (true) {
        for (int i = 1; i <= n; i++) iw[i] = inf;
        for (int i = 1; i <= idx; i++)
            if (ew[i] < iw[ev[i]])
                iw[ev[i]] = ew[i], iu[ev[i]] = eu[i];
        iw[rt] = iu[rt] = 0;
        for (int i = 1; i <= n; i++) {
            if (iw[i] == inf) return -1;
            res += iw[i];
        }
        cnt = 0;
        memset(low, 0, sizeof(low));
        memset(col, 0, sizeof(col));
        for (int i = 1; i <= n; i++) {
            int j = i;
            while (j != rt && low[j] != i && !col[j]) low[j] = i, j = iu[j];
            if (j != rt && !col[j]) {
                col[j] = ++cnt;
                for (int k = iu[j]; k != j; k = iu[k]) col[k] = cnt;
            }
        }
        if (!cnt) return res;
        for (int i = 1; i <= n; i++) if (!col[i]) col[i] = ++cnt;
        int tmp = idx; idx = 0;
        for (int i = 1; i <= tmp; i++)
            if (col[eu[i]] != col[ev[i]])
                addedge(col[eu[i]], col[ev[i]], ew[i] - iw[ev[i]]);
        n = cnt, rt = col[rt];
    }
    return res;
}

int main() {
    n = read(), m = read(), s = read();
    while (m--) {
        int u = read(), v = read(), w = read();
        if (u == v) continue;
        addedge(u, v, w);
    }
    write(edmonds(s));
    flush;
    return 0;
}

DMST 算法

事实上DMST是Direct Minimum Spanning Tree(有向最小生成树即最小树形图)的缩写

让我们回顾一下朱刘算法在一次循环中的步骤:

欸!我们发现了最小值合并两个关键字,可以使用可并堆进行维护一个节点的入边!

于是上面三个操作就分别对应了可并堆上的三个操作:

然后我们就可以动态地去找到一个节点的最小入边,然后有环的话就进行缩环就行了。

当我们把一些环上的节点合并为一个节点时,我们同时也需要再次计算这个环缩成的节点的最小入边。所以我们需要用一个队列或栈(下面是用栈实现)储存接下来需要求最小入边的节点。初始的时候把所有节点放进去,如果缩完一个环就把代表这个环的节点放进去。

下面我们再明确一遍算法过程:

  1. 把所有节点放入栈中,把所有节点的入边加入这个节点的堆
  2. 一直循环直到栈不为空
    • 找到当前节点的最小入边
    • 如果新加入的边使树形图形成环
      • 答案加上环上所有边的边权
      • 把环上的所有边扔掉(再每个环上节点的堆 pop 一次,因为环上的边一定是最小边)
      • 把每个环上的节点的堆中的边的权值减去这个节点在环中的入边的边权(朱刘算法的步骤)
      • 把所有环上的节点的堆合并成一个堆,缩成一个节点
      • 把缩成的节点加入栈中
  3. 最终的答案即为最小树形图的边权和

复杂度分析

首先我们要把所有节点放入栈,并且把所有边放入堆。前者的时间复杂度明显是 O(n) 的,后者我们可以采用线性建堆的方法优化到 O(m) 的。

假设现在要把 k 个节点放入一个左偏树中,并且 k=2^c,c\in N,那么我们可以从下往上合并,时间复杂度为

\begin{align} \sum_{i=0}^c2^{c-i}O(\log_2(2^i))&=O\left(\sum_{i=0}^c2^{c-i}i\right)\\ &=O\left(2^c\sum_{i=0}^{\infty}\frac{i}{2^i}\right)\\ &=O(2^c)=O(k) \end{align}

于是插入的总时间复杂度就是 \sum O(k)=O\left(\sum k\right)=O(m) 的了。

接着是算法的主体。我们先分析缩一个节点个数为 k 的环的时间复杂度。首先要找到每个节点的最小入边,这一步在左偏树上是 \Theta(1) 的,一共 \Theta(k);接着要在每个堆打懒标记,在左偏树上是 \Theta(1) 的,一共 \Theta(k);最后要把所有节点的左偏树合并,直接按顺序合并,一次是 O(\lg m) 的,一共 O(k\lg m)。最后加起来是 O(k\lg m) 的。

如果我们把一个大小为 k 的环缩成一个点,那么产生的复杂度贡献为 k\lg m,并且会让图上的节点减少 k-1 个。所以每次缩环的系欸但的个数和是与 n 同阶的,即 \sum_ik=\Theta(n),所以总复杂度为

\begin{align} \sum_{i}O(k\lg m)&=O\left(\sum_{i}(k\lg m)\right)\\ &=O\left(\left(\sum_ik\right)\lg m\right)\\ &=O(n\lg m) \end{align}

最后两部分加起来就得到了算法的最终复杂度:

O(m+n\lg m)

实现

这里主要有两个问题:怎么判断新加入的边产生了环和维护每个节点所在的环。

第一个问题,像朱刘算法一样,如果我们从要加入的边的出发的节点往上跳,那么就会跳到接下来要加入的点。所以我们可以用并查集加路径压缩维护从一个节点目前能往上跳到的最后面的节点。

第二个问题,我们也可以用并查集加路径压缩直接解决。

这两个并查集所带来的复杂度一共是 O(n\alpha(n)) 的,并不影响复杂度的上界。

#include <cstring>
#include <iostream>
#include <vector>
#define int long long

using namespace std;

char buf[1 << 21], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
int read() {
    int x = 0, f = 1; char ch = getchar();
    while (ch < '0' || ch > '9') f = (ch == '-' ? -1 : f), ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    return x * f;
}

char wbuf[(1 << 21) + 1], *p3 = wbuf;

#define flush (fwrite(wbuf, 1, p3 - wbuf, stdout), p3 = wbuf)
#define putchar(__x__) (p3 == wbuf + (1 << 21) ? flush : p3, (*p3++) = (__x__))
#define endl putchar('\n')
#define space putchar(' ')

void write(int x) {
    static int stk[100], top;
    if (!x) return void(putchar('0'));
    if (x < 0) putchar('-'), x = -x;
    top = 0;
    while (x) stk[++top] = x % 10, x /= 10;
    while (top) putchar(stk[top--] + '0');
}

void write(const char* str) {
    for (int i = 0; str[i]; i++) putchar(str[i]);
}

constexpr int N = 1e5 + 10, M = 1e6 + 10;
int n, m, s;

struct Edge {
    int u, v, w;
};

struct {
    Edge val[M];
    int ls[M], rs[M], dist[M], tag[M], idx;

    void pushdown(int u) {
        if (ls[u]) val[ls[u]].w += tag[u], tag[ls[u]] += tag[u];
        if (rs[u]) val[rs[u]].w += tag[u], tag[rs[u]] += tag[u];
        tag[u] = 0;
    }

    int merge(int u, int v) {
        if (!u || !v) return u | v;
        if (val[u].w > val[v].w) swap(u, v);
        pushdown(u);
        rs[u] = merge(rs[u], v);
        if (dist[ls[u]] < dist[rs[u]]) swap(ls[u], rs[u]);
        dist[u] = dist[rs[u]] + 1;
        return u;
    }

    void addtag(int u, int v) {
        if (!u) return;
        tag[u] += v, val[u].w += v;
    }

    int build(const vector<Edge>& a, int l, int r) {
        if (l > r) return 0;
        if (l == r) {
            ++idx;
            val[idx] = a[l];
            return idx;
        }
        int mid = (l + r) >> 1;
        return merge(build(a, l, mid), build(a, mid + 1, r));
    }

    void print(int u, int dep = 0) {
        if (!u) return;
        for (int i = 0; i < dep; i++) putchar('-');
        write(u), write(": "), write(ls[u]), space, write(rs[u]), space, write(dist[u]), write(" ["), write(val[u].u), write(", "), write(val[u].v), write(", "), write(val[u].w), write("]\n");
        pushdown(u);
        print(ls[u], dep + 1);
        print(rs[u], dep + 1);
    }
} heap;

struct {
    Edge e[M];
    int h[N], ne[M], idx;
    vector<Edge> in[N];

    void addedge(int u, int v, int w) {
        ++idx;
        e[idx] = {u, v, w};
        ne[idx] = h[u];
        h[u] = idx;
        in[v].push_back({u, v, w});
    }
} graph;

struct {
    bool vis[N];
    Edge chose[N];
    int id[N], col[N], fr[N], stk[N], top;

    int findCol(int u) {
        return col[u] == u ? u : col[u] = findCol(col[u]);
    }

    int findRoot(int u) {
        return fr[u] == u ? u : fr[u] = findRoot(fr[u]);
    }

    int calc() {
        for (int i = 1; i <= n; i++) id[i] = heap.build(graph.in[i], 0, graph.in[i].size() - 1);
        for (int i = 1; i <= n; i++) col[i] = fr[i] = i;
        top = 0;
        for (int i = 1; i <= n; i++) if (i != s) stk[++top] = i;
        while (top) {
            int u = stk[top--];
            if (!id[u]) return -1;
            chose[u] = heap.val[id[u]];
            if (findRoot(findCol(chose[u].u)) == u) {
                int sum = chose[u].w;
                for (int i = findCol(chose[u].u); i != u; i = findCol(chose[i].u)) sum += chose[i].w;
                id[u] = heap.merge(heap.ls[id[u]], heap.rs[id[u]]);
                for (int i = findCol(chose[u].u); i != u; i = findCol(chose[i].u)) id[i] = heap.merge(heap.ls[id[i]], heap.rs[id[i]]);
                heap.addtag(id[u], sum - chose[u].w);
                for (int i = findCol(chose[u].u); i != u; i = findCol(chose[i].u)) heap.addtag(id[i], sum - chose[i].w);
                for (int i = findCol(chose[u].u); i != u; i = findCol(chose[i].u)) id[u] = heap.merge(id[u], id[i]);
                for (int i = findCol(chose[u].u); i != u; i = findCol(chose[i].u)) col[i] = u, fr[i] = u;
                fr[u] = u;
                stk[++top] = u;
            } else {
                fr[u] = chose[u].u;
            }
        }
        int res = 0;
        for (int i = 1; i <= n; i++) {
            if (!vis[findCol(i)]) {
                vis[findCol(i)] = true;
                res += chose[findCol(i)].w;
            }
        }
        return res;
    }
} dmst;

signed main() {
    n = read(), m = read(), s = read();
    for (int i = 1; i <= m; i++) {
        int u = read(), v = read(), w = read();
        if (u == v) continue;
        graph.addedge(u, v, w);
    }
    write(dmst.calc());
    flush;
    return 0;
}

例题选讲

P2792 [JSOI2008] 小店购物

因为每个商品都需要买多次,所以如果出现买 A 可以优惠买 B,买 B 可以优惠买 A,那么最优的策略就是先买一个 A,再买一个 B,最后把 A 买完。

题目中说只要买过一个商品一次,那么优惠就一直存在,没有数量限制,所以我们可以把一个商品拆成两个点,一个表示买了 1 个这个商品,另一个表示买了 m_i-1 个这个商品,这样就可以处理上面的例子。

然后从超级源点往每个商品拆出的节点 uvu 代表买 1 个,v 代表买 m_i-1 个)分别连一条权值为 c_ic_i\times(m_i-1) 的边;对于优惠 (A,B,p),从 u_Au_B、从 v_Bu_B 连一条权值为 p 的边,从 u_Av_B、从 v_Av_B 连一条权值为 p\times(m_B-1) 的边。

之后跑最小树形图就是最终答案。

因为数据比较小,所以朱刘算法就可以直接过。