测测你的 01 矩阵乘法

· · Algo. & Theory

测测你的 01 矩阵乘法!

这里的 01 矩阵乘法是指模 2 意义下的矩阵乘法,也即加法为异或、乘法为与的矩阵乘法。

暴力容易做到 \mathcal O(n^3)

for (int i = 0; i < N; i++) {
    for (int k = 0; k < N; k++) {
        for (int j = 0; j < N; j++) c[i][j] ^= a[i][k] & b[k][j];
    }
}

注意循环顺序:不仅有助于内存连续访问,更有助于想出进一步的优化。

注意到上面的代码等价于:

for (int i = 0; i < N; i++) {
    for (int k = 0; k < N; k++) {
        for (int j = 0; j < N; j++) if (a[i][k]) c[i][j] ^= b[k][j];
    }
}

容易想到使用 bitset 优化:

for (int i = 0; i < N; i++) {
    for (int k = 0; k < N; k++) if (a[i][k]) c[i] ^= b[k];
}

这样复杂度就做到了 \mathcal O\left(\dfrac{n^3}{\omega}\right)

考虑进一步的优化。

k 这一维分块。设块长为 L,每个块 [l,r] 内处理出 2^L 种矩阵 B 的行向量的选取方式对应的异或和。

然后计算该块对 A\times B 的每一行的贡献,对于第 i 行,直接看 A_{il}\sim A_{ir} 的取值情况即可。

复杂度为 \mathcal O\left(\dfrac{n^22^L}{\omega L}+\dfrac{n^3}{\omega L}\right),取 L=\log_2n,有最优复杂度 \mathcal O\left(\dfrac{n^3}{\omega\log n}\right)

这个代码很好写。

for (int l = 0, r = B; l ^ N; l = r, r += B) {
    if (__builtin_expect(r > N, 0)) r = N;
    for (int i = 1; i ^ N; i++) f[i] = f[i ^ Lb[i]] ^ b[l + Lg[Lb[i]]];
    for (int i = 0, k = 0; i ^ N; i++, k = 0) {
        for (int j = l; j ^ r; j++) k ^= a[i][j] << j - l;
        c[i] ^= f[k];
    }
}

跑得还是很快的,4096 只跑了四分之一秒左右。