矩阵乘法的 Strassen 算法
Hexarhy
2021-06-13 18:12:26
## Update
1. [2021/11/14] 感谢 @serverkiller 指出错误并修正。
1. [2021/11/14] 替换了扩展阅读第二个 link。
# Introduction
Strassen 算法在 OI 上**没有任何应用价值**,不过了解一下理论计算机科学相关还蛮有意思的。
计算两个矩阵相乘的朴素方法是 $\Theta(n^3)$。而 Strassen 算法通过采用**分治**策略,并**减少递归次数**,实现在 $\Theta(n^{\log 7})$ 时间内完成。$\log 7\approx 2.808$,也就是 $O(n^{2.81})$。
**前置知识:** 矩阵基础,主定理。
其实主定理只需要知道结论即可,这里放一下简化结论:
$$
T(n)=aT\left(\dfrac{n}{b}\right)+O(n^d),a\ge 1,b>1\\
T(n)=\begin{cases}
O(n^d)&,d>\log_{b}a\\
O(n^d\log n)&,d=\log_b a\\
O(n^{\log_b a})&,d<\log_b a
\end{cases}
$$
# Analysis
### 从分治开始
在介绍 Strassen 算法之前,先探讨如何用分治完成矩阵乘法。
分治策略其实非常直截了当,就是将 $n\times n$ 的矩阵划分为四个 $\dfrac{n}{2}\times \dfrac{n}{2}$ 的矩阵。当然,这种做法要求 $n$ 是 $2$ 的幂,但相关细节我们稍后探讨。
先以 $2\times 2$ 的矩阵为例。
$$
\begin{bmatrix}A_{1,1}&A_{1,2}\\A_{2,1}&A_{2,2}\end{bmatrix}\cdot\begin{bmatrix}B_{1,1}&B_{1,2}\\B_{2,1}&B_{2,2}\end{bmatrix}=\begin{bmatrix}C_{1,1}&C_{1,2}\\C_{2,1}&C_{2,2}\end{bmatrix}
$$
具体写出计算矩阵 $C$ 的等式:
$$
C_{1,1}=A_{1,1}\cdot B_{1,1}+A_{1,2}\cdot B_{2,1}\\
C_{1,2}=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\
C_{2,1}=A_{2,1}\cdot B_{1,1}+A_{2,2}\cdot B_{2,1}\\
C_{2,2}=A_{2,1}\cdot B_{1,2}+A_{2,2}\cdot B_{2,2}\\
$$
对于每一条公式,相当于计算两对 $\dfrac{n}{2}\times\dfrac{n}{2}$ 矩阵乘法,再计算一次这样的矩阵加法。
直接利用这个运算方式就可以写出递归分治策略。形象化地说,就是将矩阵分解为左上,左下,右上,右下四个子矩阵再分别进行运算。
$$
\boxed{
\begin{array}{ll}
&\textbf {Function Multiply}(A,B)\\
1& n\gets A.\text{rows}\\
2& \text{let}\ C\ \text{be a new } n\times n\ \text{matrix}\\
3& \textbf{if } n=1\\
4& \qquad c_{1,1}\gets a_{1,1}\times b_{1,1}\\
5&\textbf{else}\ \mathrm{Partion}\ A,B,C\\
6&\qquad C_{1,1}\gets\text {Multiply}(A_{1,1},B_{1,1})+\text {Multiply}(A_{1,2},B_{2,1})\\
7&\qquad C_{1,2}\gets\text {Multiply}(A_{1,1},B_{1,2})+\text {Multiply}(A_{1,2},B_{2,2})\\
8&\qquad C_{2,1}\gets\text {Multiply}(A_{2,1},B_{1,1})+\text {Multiply}(A_{2,2},B_{2,1})\\
9&\qquad C_{2,2}\gets\text {Multiply}(A_{2,1},B_{1,2})+\text {Multiply}(A_{2,2},B_{2,2})\\
10&\textbf{return}\ C
\end{array}}
$$
注意 $\rm Partion$ 部分只需要对应好分解后的矩阵的下标即可。具体的对应方法可以参考 Strassen 算法伪代码。
需要说明的是,《算法导论》认为,只需通过下标计算即可对实现分解子矩阵并操作,而不用 $\Theta(n^2)$ 拷贝子矩阵。然而在亲自动手实现代码时,避免拷贝子矩阵来进行其他操作是异常麻烦的,书中也没有给出伪代码。况且拷贝也不影响总的时间复杂度,因为矩阵加法需要不可避免的 $\Theta(n^2)$。但拷贝操作对常数因子影响比较大。~~如果有谁会实现不用拷贝的请务必把代码发给我/kk~~
现在来看一下这段代码的时间复杂度。
我们分解出了 $8$ 个子问题,每个子问题规模缩小了一半,同时花了 $\Theta(n^2)$ 时间分解出子矩阵,花了 $\Theta(n^2)$ 进行矩阵加法。容易写出其递归式:
$$
\begin{aligned}
T(n)&=8T\left(\dfrac{n}{2}\right)+\Theta(n^2)
\end{aligned}
$$
运用主定理即可求解时间复杂度为 $\Theta(n^3)$。
没有优化啊?这就来到 Strassen 算法的另一个核心:减少递归次数。
### 还能再少一次
Strassen 算法在朴素分治算法的基础上,只进行了 $7$ 次递归。当然减少递归次数的代价就是多进行了几次矩阵加法,但幸好只是常数级别。
我们先对时间复杂度进行分析。递归式为:
$$
T(n)=7T\left(\dfrac{n}{2}\right)+\Theta(n^2)
$$
运用主定理求解出时间复杂度为 $\Theta(n^{\log_2 7})$,也就是 $O(n^{2.81})$。
步骤上就比朴素的分治算法要麻烦一些。
1. 与朴素分治算法相同,分解出左上,左下,右上,右下四个子矩阵。
2. 创建 $10$ 个 $\dfrac{n}2\times\dfrac n2$ 的矩阵 $S_i$,每个 $S_i$ 保存两个子矩阵的和或差。
3. 用子矩阵和 $S_i$ 相乘,递归地计算 $7$ 个 $\dfrac{n}2\times\dfrac n2$ 的矩阵 $P_i$。
4. 通过 $P_i$ 的不同组合进行加减,得到 $C$ 的子矩阵。
具体地,步骤 2 中创建的 $10$ 个矩阵 $S_i$ 分别为:
$$
\begin{array}{ll}
S_1&=B_{1,2}-B_{2,2}\\
S_2&=A_{1,1}+A_{1,2}\\
S_3&=A_{2,1}+A_{2,2}\\
S_4&=B_{2,1}-B_{1,1}\\
S_5&=A_{1,1}+A_{2,2}\\
S_6&=B_{1,1}+B_{2,2}\\
S_8&=A_{1,2}-A_{2,2}\\
S_8&=B_{2,1}+B_{2,2}\\
S_9&=A_{1,1}-A_{2,1}\\
S_{10}&=B_{1,1}+B_{1,2}
\end{array}
$$
步骤 3 中需要递归计算的 $7$ 个矩阵 $P_i$ 分别为:
$$
\begin{array}{ll}
P_1&=A_{1,1}\cdot S_1\\
P_2&=S_2\cdot B_{2,2}\\
P_3&=S_3\cdot B_{1,1}\\
P_4&=A_{2,2}\cdot S_4\\
P_5&=S_5\cdot S_6\\
P_6&=S_7\cdot S_8\\
P_7&=S_9\cdot S_{10}
\end{array}
$$
到了步骤 4,计算 $C$ 的子矩阵的方法为:
$$
\begin{array}{ll}
C_{1,1}&=P_5+P_4-P_2+P_6\\
C_{1,2}&=P_1+P_2\\
C_{2,1}&=P_3+P_4\\
C_{2,2}&=P_5+P_1-P_3-P_7
\end{array}
$$
这些式子为什么正确?直接代入即可验证。由于验证过程过于冗长,这里只举一例 $C_{1,2}=P_1+P_2$。
$$
\begin{aligned}
P_1+P_2&=A_{1,1}\cdot S_1+S_2\cdot B_{2,2}\\
&=A_{1,1}\cdot (B_{1,2}-B_{2,2})+(A_{1,1}+A_{1,2})\cdot B_{2,2}\\
&=A_{1,1}\cdot B_{1,2}-A_{1,1}\cdot B_{2,2}+A_{1,1}\cdot B_{2,2}+A_{1,2}\cdot B_{2,2}\\
&=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\
&=C_{1,2}
\end{aligned}
$$
至于 Strassen 具体是如何想到构造出这些算式的,则留给我们作为无限的遐想。有兴趣可以浏览[这里](https://softwareengineering.stackexchange.com/questions/199627/how-did-strassen-come-up-with-his-matrix-multiplication-method)。
# Exercises
节选自《算法导论》4.2 练习。
1. 试只用三次乘法完成复数相乘 $(a+b\mathrm{i})(c+d\mathrm{i})=(ac-bd)+(ad+bc)\mathrm{i}$。
设
$$
\begin{cases}\alpha=ac\\ \beta=bd\\\gamma=(a+b)(c+d)\\\end{cases}
$$
则:
$$
(a+b\mathrm{i})(c+d\mathrm{i})=(\alpha-\beta)+(\gamma-\alpha-\beta)\mathrm{i}
$$
计算 $\alpha,\beta,\gamma$ 只用了三次乘法,代价就是增加了加法次数。事实上 Gauss 早已发现了三次乘法进行复数相乘的方法,说不定 Strassen 是受到了这个启发?
2. 若矩阵规模不是 $2$ 的幂,如何应用 Strassen 算法?
用值 $0$ 补齐到 $2$ 的幂即可。这是实现代码时需要注意的地方。
3. 已知用 $k$ 次乘法操作完成两个 $3\times 3$ 的矩阵相乘,那么满足在 $o(n^{\log 7})$ 的时间内完成 $n\times n$ 的矩阵相乘,$k$ 的最大值是多少?
容易列出递归式并求解:
$$\begin{aligned}
T(n)&=kT\left(\dfrac n3\right)+O(n^2)\\
T(n)&=O(\log_3 k)\\
\log_3k&<\log_2 7\\
k_{\max}&=21
\end{aligned}$$
4. 编写 Strassen 算法的伪代码。
凑合着看吧,这里把 $\rm Partion$ 部分具体写了出来。
其中 $A[1\sim n/2][1\sim n/2]$ 表示由 $\forall i\in[1,n/2],\forall j\in[1,n/2],A_{i,j}$ 组成的子矩阵,其余类似。
$$
\boxed{
\begin{array}{ll}
&\textbf{Function Strassen}(A, B)\\
1& n \gets A.\mathrm{rows}\\
2& \textbf{if}\ n = 1\\
3& \qquad \textbf{return}\ a[1, 1]\times b[1, 1]\\
4& \mathrm{let}\ C\ \mathrm{be\ a\ new}\ n \times n\ \mathrm{matrix}\\
5& A[1, 1] \gets A[1\sim n / 2][1\sim n / 2]\\
6& A[1, 2] \gets A[1\sim n / 2][n / 2 + 1\sim n]\\
7& A[2, 1] \gets A[n / 2 + 1\sim n][1\sim n / 2]\\
8& A[2, 2] \gets A[n / 2 + 1\sim n][n / 2 + 1\sim n]\\
9& B[1, 1] \gets B[1\sim n / 2][1\sim n / 2]\\
10& B[1, 2] \gets B[1\sim n / 2][n / 2 + 1\sim n]\\
11& B[2, 1] \gets B[n / 2 + 1\sim n][1\sim n / 2]\\
12& B[2, 2] \gets B[n / 2 + 1\sim n][n / 2 + 1\sim n]\\
13& S[1] \gets B[1, 2] - B[2, 2]\\
14& S[2] \gets A[1, 1] + A[1, 2]\\
15& S[3] \gets A[2, 1] + A[2, 2]\\
16& S[4] \gets B[2, 1] - B[1, 1]\\
17& S[5] \gets A[1, 1] + A[2, 2]\\
18& S[6] \gets B[1, 1] + B[2, 2]\\
19& S[7] \gets A[1, 2] - A[2, 2]\\
20& S[8] \gets B[2, 1] + B[2, 2]\\
21& S[9] \gets A[1, 1] - A[2, 1]\\
22& S[10] \gets B[1, 1] + B[1, 2]\\
23& P[1] \gets \mathrm{Strassen}(A[1, 1], S[1])\\
24& P[2] \gets \mathrm{Strassen}(S[2], B[2, 2])\\
25& P[3] \gets \mathrm{Strassen}(S[3], B[1, 1])\\
26& P[4] \gets \mathrm{Strassen}(A[2, 2], S[4])\\
27& P[5] \gets \mathrm{Strassen}(S[5], S[6])\\
28& P[6] \gets \mathrm{Strassen}(S[7], S[8])\\
29& P[7] \gets \mathrm{Strassen}(S[9], S[10])\\
30& C[1\sim n / 2][1\sim n / 2] \gets P[5] + P[4] - P[2] + P[6]\\
31& C[1\sim n / 2][n / 2 + 1\sim n] \gets P[1] + P[2]\\
32& C[n / 2 + 1\sim n][1\sim n / 2] \gets P[3] + P[4]\\
33& C[n / 2 + 1\sim n][n / 2 + 1\sim n] \gets P[5] + P[1] - P[3] - P[7]\\
34& \textbf{return}\ C\\
\end{array}}
$$
~~至于 C++ 代码,就留作读者课后习题吧。~~
# Notice
然而很遗憾的是,Strassen 算法由于使用了大量递归,多次创建临时矩阵,常数因子比朴素矩阵乘法大非常多,因此 OI 范围内几乎没有应用价值。朴素的矩阵乘法就足够了。
如果确实想要优化常数,可以从以下方面考虑:
- 使用指令集等卡常科技。
- 在子矩阵小到一定规模时使用朴素矩阵乘法。
- 尽可能地减少创建临时矩阵和复制次数。
据 w33z8kqrqk8zzzx33 在[帖子](https://www.luogu.com.cn/discuss/show/321588)里说道,使用指令集可以在 $n=2^{10}$ 时用少于 $0.2$ s 的时间完成(朴素矩阵乘法用指令集应该也会快不少吧?)。
而实际应用中的大型矩阵乘法都依赖于硬件(cache, GPU 等)和分布式计算。
尽管 Strassen 算法看上去没有什么实际用处,《算法导论》依然指出了 Strassen 算法的一个最重要的意义:在理论研究上作出了突破性的贡献。
类比 1959 年发明的 Shell Sort(希尔排序),目前来看也没有任何应用价值,但这是计算机第一次在整数排序上突破了 $\Theta(n^2)$ 的壁障(当时常见的还是插入排序和冒泡排序)。
同样地,是 Strassen 算法使得矩阵乘法在渐进上界上第一次快于 $\Theta(n^3)$,并鼓舞着后人继续在这方面探索。1969 年 Strassen 发表论文的的标题为《高斯消元法并非最优》,也正揭示了该算法的意义所在。
# Extension
Strassen 算法其实能推广到很多矩阵操作,在原论文也略有提及。其基本思路都是分治并减少递归次数。本文将对其粗略说明。至于 Strassen 构造的这些玄妙算式是如何想到的,则给后人留下了神秘的美感。
下面所有算法的时间复杂度分析都与前文矩阵乘法相类似,不再赘述。
## 矩阵求逆
一般的矩阵求逆是用 $\Theta(n^3)$ 的高斯消元法完成的。而 Strassen 在矩阵乘法的基础上,得到了更快的算法。这也是 Strassen 论文标题的由来。
基本思路依然与矩阵乘法类似。
先将矩阵 $A$ 按照左上,左下,右上,右下分解出子矩阵。
创建如下 $7$ 个矩阵 $P_i$:
$$
\begin{array}{ll}
P_1&=A_{1,1}^{-1}\\
P_2&=A_{2,1}\cdot P_1\\
P_3&=P_{1}\cdot A_{1,2}\\
P_4&=A_{2,1}\cdot P_3\\
P_5&=P_4-A_{2,2}\\
P_6&=P_5^{-1}\\
P_7&=P_{3}\cdot P_6\cdot P_2\\
\end{array}
$$
然后计算出 $A$ 的逆矩阵 $A^{-1}$ 的 $4$ 个子矩阵:
$$
\begin{array}{ll}
A^{-1}_{1,1}&=P_1-P_7\\
A^{-1}_{1,2}&=P_3\cdot P_6\\
A^{-1}_{2,1}&=P_6\cdot P_2\\
A^{-1}_{2,2}&=-P_6
\end{array}
$$
当子矩阵缩小到一定规模,我们就可以直接用高斯消元法求解来减小常数。
当然,递归求逆过程中也要顺便用 Strassen 算法求矩阵乘法。
需要说明的是,Strassen 假定了操作过程中所有矩阵都是可逆的,而对于更复杂的情况则束手无策。对于 Strassen 算法在矩阵求逆操作上的更深研究,限于篇幅请参考[这里](https://arxiv.org/pdf/1901.00904.pdf)。
## 解线性方程组
> Similar results hold for solving a system of linear equations or computing a
determinant.
Strassen 对于矩阵求逆和行列式计算的记录相当简略啊。
我们知道线性方程组可以写成矩阵的形式。
$$
\begin{bmatrix}
a_{1,1}&a_{1,2}&a_{1,3}&\cdots& a_{1,m}\\
a_{2,1}&a_{2,2}&a_{2,3}&\cdots& a_{2,m}\\
a_{3,1}&a_{3,2}&a_{3,3}&\cdots& a_{3,m}\\
\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n,1}&a_{n,2}&a_{n,3}&\cdots& a_{n,m}\\
\end{bmatrix}
\cdot
\begin{bmatrix}
x_1\\
x_2\\
x_3\\
\vdots\\
x_n
\end{bmatrix}=
\begin{bmatrix}
y_1\\
y_2\\
y_3\\
\vdots\\
y_n
\end{bmatrix}
$$
上式简记为 $A\cdot x=B$,则有 $x=A^{-1}\cdot
B$。
如果逆矩阵存在,则方程恰好有一解,直接用 Strassen 算法进行矩阵求逆即可。
## 行列式
同样把矩阵 $A$ 分解成四个子矩阵,然后用一种简洁的计算方式求解即可:
$$
\operatorname{det}(A)=\operatorname{det}(A_{1,1})\cdot \operatorname{det}\left(A_{2,2}-A_{2,1}\cdot A_{1,1}^{-1}\cdot A_{1,2}\right)
$$
# Reference
- 《算法导论》第三版第四章。
- [Strassen 原论文 - _Gaussian Elimination is not Optimal_](http://scgroup.hpclab.ceid.upatras.gr/class/SC/Papers/Strassen.pdf)
关于矩阵乘法的其它研究,感兴趣的可以阅读以下文献:
1. 用 $O(n^{2.37286})$ 时间完成矩阵乘法 - [Link](https://arxiv.org/pdf/2010.05846.pdf)
2. 对于 $n\times n$ 的矩阵左乘上 $n\times n^p(p\le 0.294)$ 的矩阵,可以做到 $O(n^2)$ 的时间复杂度 - [Link](https://pdf.sciencedirectassets.com/272569/1-s2.0-S0885064X00X00374/1-s2.0-S0885064X97904386/main.pdf?X-Amz-Security-Token=IQoJb3JpZ2luX2VjEJv%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaCXVzLWVhc3QtMSJHMEUCIQD5OIkH%2FHrGP2VSizxsPWtgdganIIa%2FB93f%2FOO%2FooGjiQIgd6aRmE4mA2vIUvv%2B%2BHBxszexkiqMVl7haUxj%2FFlaO1kq%2BgMIUxAEGgwwNTkwMDM1NDY4NjUiDMpNvI3WPqYT0HQv%2FirXAzjMtqcxfTj3gTbdL%2FCPQIrJZZVwkh63%2F1GxZO5oeVdVNZmJfvfVZPprzYFvSlF29uSig4Yle1I%2Fs6yu7c3KHQ3vDHo%2F8NUta0zSHTbY9HfQxz%2Bp2uZiQGYuSZhJplFO%2BYnlyk9Sz6b3AZI1SX%2B8gX2KzSW1zsEfiAjzd5ChktRMovbjuHEZ6otoDq0iuNUkGQemZ5KsTvtZFDUkZizulRBj3OWjSgxu%2Bux3Gh45MXKZ%2B6jghRuEP0shfljhxLgBiDoKzpRcwIQTqVQ0UJxx3neAdo1xcZ2SR%2Bwjdi6%2FoYecK1IDT57huxV1x3gNyYbFLmp3G52KaCNF35uxCLVxZBik%2Fudv31lAjnH%2B3x7IAuiEUUO8ATOYjDualBvswHEPs6e3qs1x5StWlMWDdqm5urBbOtkohIJS0S3t6yvj33tzlOfPrQKxZnVE4XALjNOr6mn%2BDCInKgRutsaGgSwp%2FPUIVL3zWXRYy5WsMAfkG8F5pkmt1zW543EySVIhAV8CbZ9MoAYcaT6TgI00WFIKcVq70Hu4cfZFaaMRnX3190pB1BIBP8xC4zvhdQcKmrmovD84a10g51oBKdbQTZMq8zkKU%2B%2F51B1enL8GLSvutlkLXeBBw9s8mTCH38GMBjqlAc8Y3w9sTgUMZptwWgyXyWkL2Yv38TTAF%2FmSDnYKmcLoUZy4ajgrb52NyaNh60g1BJw1TDuLLA02HoHsoyo%2FE2ydPKtRSgMRDDGU5gyVPpK249T%2FzUkDSXOmgzmIqEwa3FswC1oHY%2BwYyKfiXsU7wHVz1GFbh3yPQ5enu%2FczBSDk2M2uF3qc6Id5kuS03pEiknP3q0aMHUGSy0wt27LyFnHx%2F8KeNg%3D%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20211114T035342Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAQ3PHCVTY6CG4QTXU%2F20211114%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=6ce6eb21e25cc05bc3e1f21aea5ecd4934ebeea6e42214b9449c170727fd5339&hash=2ee762bfa2b44c1868fcb13aeeb0e8891aec38d0c96941b284a0711565b32078&host=68042c943591013ac2b2430a89b270f6af2c76d8dfd086a07176afe7c76c2c61&pii=S0885064X97904386&tid=spdf-6c4f336c-6d87-4148-bd22-1cb4d22fe9fc&sid=cec234191076e046d48945896fd4d48b5d95gxrqa&type=client)
3. 关于 GPU 上进行矩阵乘法的效率 - [Link](http://graphics.stanford.edu/papers/gpumatrixmult/gpumatrixmult.pdf)