Matrix-Tree 定理

· · 算法·理论

Matrix-Tree 定理证明

跟 @whjhr 不同的容斥证明。

摘自与 Ender 的QQ聊天讨论记录。

七天内有效

拜谢 Ender

2024.1.5 update

@Judgelight 说觉得有点抽象,要求翻译一下:

思路

不妨设算的是 所有边从叶子连向根 的生成树 的个数。

所以相当于要给每个除根以外的点找一个其要连向的父亲(删掉一行一列的含义)

但是不是每种方案都是合法的。可以发现,当且仅当不存在环时方案合法。

考虑容斥:至少 0 个环的方案 - 至少 1 个环的方案 + 至少 2 个环的方案 - ......

实现

Kirchhoff 矩阵 中 a_{i,i} 的含义即为 i 所有能选的父亲的方案数,-a_{i,j} 即为 i->j 的边的个数(j->i 也可以,不影响。后文皆理解为 i->j

考虑行列式计算中每个排列所代表的含义

如果第 i 行取了第 i 个数,那么代表 i 在能选的父亲中随便选了一个

否则第 i 行取了第 j 个数(j\ne i),那么其代表的是选了一条 i->j 的边。

相应的,第 j 行肯定也选了一个异于 j 的数 k(因为选的是排列),代表选了一条 j->k 的边。

而最终一定会回到 i,即形成了一个环。

所以推出这个排列乘出来数的绝对值 是形成某些环的方案的总数

那正负号为啥正确呢?

我们期望的符号是跟环的数量定的,环有偶数个则为正, 环有奇数个则为负

而实际计算中,符号由两个因素影响。一是 a 矩阵乘出来的,二是逆序对数

随便取的 $i$ 即 $a_{i,i}$,为正 环上的边 $i->j$ 即 $a_{i,j}$,为负 所以 $a$ 矩阵的贡献为 $\sum_{i=1}^k l_i

其中 l_i 为环 i 的长度,k 为环个数

逆序对:

对于随便取的 i,其他随便取的 j 与其没有贡献。而其他环中从 i 左边到右边的数量等于右边到左边的数量。所以含 i 的逆序对数为偶

对于环 i 来说,假设其的轮换顺序为:x_1,x_2,x_3,......,x_{l_i}。则交换 x_1x_{l_i} 位置上的数后其会变成 x_2,x_3,......,x_{l_i} 即一个长度减少 1 的环。且此次交换对逆序对的贡献为 1。所以将一个环消掉的贡献为 l_i - 1,而将所有环消掉后序列变为 1 ~ n。所以原序列的贡献为 \sum_{i=1}^k (l_i - 1)

总贡献为 \sum_{i=1}^k l_i - \sum_{i=1}^k (l_i - 1) = k,与期望相符

综上得证

2024.1.7 update

因为 @nodgd 说可以 O(n^3) 求这个玩意儿,就是把枚举根优化掉了,于是又去问了一波有了这个拓展,我愿称之为 exMatrix-Tree 定理 :

exMatrix-Tree 定理

发现给 Kirchhoff 矩阵的对角线每个数加一个变量 x 后要求的即是行列式求值后 x 的一次项系数,所以直接把原来每个数添上 x 这一维,并令 x^2 = 0 算行列式即可。

但不是很好算,因为多项式的消元要么变成实数域,要么整数辗转相除出不出来,比如 (x+1)/2

怎么办捏,发现此时将所有行(也可能是列,看怎么写的)相加并替换第一行(列)后全是 x,即常数项全消掉了,x 的系数也都是 1。这个新矩阵和原来的矩阵行列式是相等的,并且发现现在可以将 x 抹掉了!第一行(列)变成全 1 即可。

Code

#include <iostream>
typedef long long ll;
const ll N = 8;
ll read() {ll x; return scanf("%lld", &x), x;}
ll read1() {ll x; return scanf("%1lld", &x), x;}
ll a[N + 5][N + 5];
int main()
{
    ll n = read();
    for (ll i = 1; i <= n; i ++ )
        for (ll j = 1; j <= n; j ++ )
        {
            ll u = j, v = i, w = read1();
            a[u][u] += w, a[v][u] -= w;
        }
    for (ll i = 1; i <= n; i ++ ) a[1][i] = 1;
    ll f = 1;
    for (ll i = 1; i <= n; i ++ )
        for (ll j = i + 1; j <= n; j ++ )
        {
            while (a[i][i])
            {
                ll t = a[j][i] / a[i][i];
                for (ll k = i; k <= n; k ++ ) a[j][k] -= t * a[i][k];
                std :: swap(a[i], a[j]), f = -f;
            }
            std :: swap(a[i], a[j]), f = -f;
        }
    ll res = f;
    for (ll i = 1; i <= n; i ++ ) res *= a[i][i];
    printf("%lld", res);
    return 0;
}

2024.2.17 update

胡扯赛考到了 exMatrix-Tree,可以作为例题(

Code

#include <iostream>
#include <cstring>
#include <vector>
#include <cmath>
typedef long long ll;
const int N = 16;
const int M = 125;
const int mod = 1e9 + 7;
int read() {int x; return scanf("%d", &x), x;}
int u[M + 5], v[M + 5], w[M + 5];
int g[N + 5][N + 5];
int f[1 << N | 5], ans;
inline int det(int a[][N + 5], int n)
{
    for (int i = 1; i <= n; i ++ )
        for (int j = 1; j <= n; j ++ )
            (a[i][j] < 0) && (a[i][j] += mod);
    int res = 1;
    for (int i = 1; i <= n; i ++ )
    {
        int t = i;
        for (int j = i; j <= n; j ++ ) if (a[j][i]) {t = j; break;}
        if (!a[t][i]) return 0;
        if (t != i) res = mod - res, std :: swap(a[i], a[t]);
        for (int j = i + 1; j <= n; j ++ )
        {
            int x = i, y = j;
            while (a[y][i])
            {
                int t = mod - a[x][i] / a[y][i];
                for (int k = i; k <= n; k ++ ) a[x][k] = (a[x][k] + 1ll * a[y][k] * t) % mod;
                std :: swap(x, y);
            }
            if (x != i) std :: swap(a[x], a[i]), res = mod - res;
        }
        res = 1ll * res * a[i][i] % mod;
    }
    return res;
}
int main()
{
    int n = read(), m = read();
    for(int i = 1; i <= m; i ++ ) u[i] = read(), v[i] = read(), w[i] = read();
    for(int i = 0; i < 1 << n - 1; i ++ )
    {
        std :: memset(g, 0, sizeof g);
        for(int j = 1; j <= m; j ++ )
            if (i >> w[j] & 1)
                g[u[j]][v[j]] -- , g[v[j]][v[j]] ++ ;
        for (int j = 1; j <= n; j ++ ) g[1][j] = 1;
        f[i] = det(g, n);
        for (int j = i & (i - 1); j; j = j - 1 & i) (f[i] += mod - f[j]) >= mod && (f[i] -= mod);
        int t = 0;
        while (i >> t & 1) t ++ ;
        ans = (ans + 1ll * f[i] * t) % mod;
    }
    printf("%d\n", ans);
    return 0;
}