矩阵乘法的 Strassen 算法

Hexarhy

2021-06-13 18:12:26

Personal

## Update 1. [2021/11/14] 感谢 @serverkiller 指出错误并修正。 1. [2021/11/14] 替换了扩展阅读第二个 link。 # Introduction Strassen 算法在 OI 上**没有任何应用价值**,不过了解一下理论计算机科学相关还蛮有意思的。 计算两个矩阵相乘的朴素方法是 $\Theta(n^3)$。而 Strassen 算法通过采用**分治**策略,并**减少递归次数**,实现在 $\Theta(n^{\log 7})$ 时间内完成。$\log 7\approx 2.808$,也就是 $O(n^{2.81})$。 **前置知识:** 矩阵基础,主定理。 其实主定理只需要知道结论即可,这里放一下简化结论: $$ T(n)=aT\left(\dfrac{n}{b}\right)+O(n^d),a\ge 1,b>1\\ T(n)=\begin{cases} O(n^d)&,d>\log_{b}a\\ O(n^d\log n)&,d=\log_b a\\ O(n^{\log_b a})&,d<\log_b a \end{cases} $$ # Analysis ### 从分治开始 在介绍 Strassen 算法之前,先探讨如何用分治完成矩阵乘法。 分治策略其实非常直截了当,就是将 $n\times n$ 的矩阵划分为四个 $\dfrac{n}{2}\times \dfrac{n}{2}$ 的矩阵。当然,这种做法要求 $n$ 是 $2$ 的幂,但相关细节我们稍后探讨。 先以 $2\times 2$ 的矩阵为例。 $$ \begin{bmatrix}A_{1,1}&A_{1,2}\\A_{2,1}&A_{2,2}\end{bmatrix}\cdot\begin{bmatrix}B_{1,1}&B_{1,2}\\B_{2,1}&B_{2,2}\end{bmatrix}=\begin{bmatrix}C_{1,1}&C_{1,2}\\C_{2,1}&C_{2,2}\end{bmatrix} $$ 具体写出计算矩阵 $C$ 的等式: $$ C_{1,1}=A_{1,1}\cdot B_{1,1}+A_{1,2}\cdot B_{2,1}\\ C_{1,2}=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\ C_{2,1}=A_{2,1}\cdot B_{1,1}+A_{2,2}\cdot B_{2,1}\\ C_{2,2}=A_{2,1}\cdot B_{1,2}+A_{2,2}\cdot B_{2,2}\\ $$ 对于每一条公式,相当于计算两对 $\dfrac{n}{2}\times\dfrac{n}{2}$ 矩阵乘法,再计算一次这样的矩阵加法。 直接利用这个运算方式就可以写出递归分治策略。形象化地说,就是将矩阵分解为左上,左下,右上,右下四个子矩阵再分别进行运算。 $$ \boxed{ \begin{array}{ll} &\textbf {Function Multiply}(A,B)\\ 1& n\gets A.\text{rows}\\ 2& \text{let}\ C\ \text{be a new } n\times n\ \text{matrix}\\ 3& \textbf{if } n=1\\ 4& \qquad c_{1,1}\gets a_{1,1}\times b_{1,1}\\ 5&\textbf{else}\ \mathrm{Partion}\ A,B,C\\ 6&\qquad C_{1,1}\gets\text {Multiply}(A_{1,1},B_{1,1})+\text {Multiply}(A_{1,2},B_{2,1})\\ 7&\qquad C_{1,2}\gets\text {Multiply}(A_{1,1},B_{1,2})+\text {Multiply}(A_{1,2},B_{2,2})\\ 8&\qquad C_{2,1}\gets\text {Multiply}(A_{2,1},B_{1,1})+\text {Multiply}(A_{2,2},B_{2,1})\\ 9&\qquad C_{2,2}\gets\text {Multiply}(A_{2,1},B_{1,2})+\text {Multiply}(A_{2,2},B_{2,2})\\ 10&\textbf{return}\ C \end{array}} $$ 注意 $\rm Partion$ 部分只需要对应好分解后的矩阵的下标即可。具体的对应方法可以参考 Strassen 算法伪代码。 需要说明的是,《算法导论》认为,只需通过下标计算即可对实现分解子矩阵并操作,而不用 $\Theta(n^2)$ 拷贝子矩阵。然而在亲自动手实现代码时,避免拷贝子矩阵来进行其他操作是异常麻烦的,书中也没有给出伪代码。况且拷贝也不影响总的时间复杂度,因为矩阵加法需要不可避免的 $\Theta(n^2)$。但拷贝操作对常数因子影响比较大。~~如果有谁会实现不用拷贝的请务必把代码发给我/kk~~ 现在来看一下这段代码的时间复杂度。 我们分解出了 $8$ 个子问题,每个子问题规模缩小了一半,同时花了 $\Theta(n^2)$ 时间分解出子矩阵,花了 $\Theta(n^2)$ 进行矩阵加法。容易写出其递归式: $$ \begin{aligned} T(n)&=8T\left(\dfrac{n}{2}\right)+\Theta(n^2) \end{aligned} $$ 运用主定理即可求解时间复杂度为 $\Theta(n^3)$。 没有优化啊?这就来到 Strassen 算法的另一个核心:减少递归次数。 ### 还能再少一次 Strassen 算法在朴素分治算法的基础上,只进行了 $7$ 次递归。当然减少递归次数的代价就是多进行了几次矩阵加法,但幸好只是常数级别。 我们先对时间复杂度进行分析。递归式为: $$ T(n)=7T\left(\dfrac{n}{2}\right)+\Theta(n^2) $$ 运用主定理求解出时间复杂度为 $\Theta(n^{\log_2 7})$,也就是 $O(n^{2.81})$。 步骤上就比朴素的分治算法要麻烦一些。 1. 与朴素分治算法相同,分解出左上,左下,右上,右下四个子矩阵。 2. 创建 $10$ 个 $\dfrac{n}2\times\dfrac n2$ 的矩阵 $S_i$,每个 $S_i$ 保存两个子矩阵的和或差。 3. 用子矩阵和 $S_i$ 相乘,递归地计算 $7$ 个 $\dfrac{n}2\times\dfrac n2$ 的矩阵 $P_i$。 4. 通过 $P_i$ 的不同组合进行加减,得到 $C$ 的子矩阵。 具体地,步骤 2 中创建的 $10$ 个矩阵 $S_i$ 分别为: $$ \begin{array}{ll} S_1&=B_{1,2}-B_{2,2}\\ S_2&=A_{1,1}+A_{1,2}\\ S_3&=A_{2,1}+A_{2,2}\\ S_4&=B_{2,1}-B_{1,1}\\ S_5&=A_{1,1}+A_{2,2}\\ S_6&=B_{1,1}+B_{2,2}\\ S_8&=A_{1,2}-A_{2,2}\\ S_8&=B_{2,1}+B_{2,2}\\ S_9&=A_{1,1}-A_{2,1}\\ S_{10}&=B_{1,1}+B_{1,2} \end{array} $$ 步骤 3 中需要递归计算的 $7$ 个矩阵 $P_i$ 分别为: $$ \begin{array}{ll} P_1&=A_{1,1}\cdot S_1\\ P_2&=S_2\cdot B_{2,2}\\ P_3&=S_3\cdot B_{1,1}\\ P_4&=A_{2,2}\cdot S_4\\ P_5&=S_5\cdot S_6\\ P_6&=S_7\cdot S_8\\ P_7&=S_9\cdot S_{10} \end{array} $$ 到了步骤 4,计算 $C$ 的子矩阵的方法为: $$ \begin{array}{ll} C_{1,1}&=P_5+P_4-P_2+P_6\\ C_{1,2}&=P_1+P_2\\ C_{2,1}&=P_3+P_4\\ C_{2,2}&=P_5+P_1-P_3-P_7 \end{array} $$ 这些式子为什么正确?直接代入即可验证。由于验证过程过于冗长,这里只举一例 $C_{1,2}=P_1+P_2$。 $$ \begin{aligned} P_1+P_2&=A_{1,1}\cdot S_1+S_2\cdot B_{2,2}\\ &=A_{1,1}\cdot (B_{1,2}-B_{2,2})+(A_{1,1}+A_{1,2})\cdot B_{2,2}\\ &=A_{1,1}\cdot B_{1,2}-A_{1,1}\cdot B_{2,2}+A_{1,1}\cdot B_{2,2}+A_{1,2}\cdot B_{2,2}\\ &=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\ &=C_{1,2} \end{aligned} $$ 至于 Strassen 具体是如何想到构造出这些算式的,则留给我们作为无限的遐想。有兴趣可以浏览[这里](https://softwareengineering.stackexchange.com/questions/199627/how-did-strassen-come-up-with-his-matrix-multiplication-method)。 # Exercises 节选自《算法导论》4.2 练习。 1. 试只用三次乘法完成复数相乘 $(a+b\mathrm{i})(c+d\mathrm{i})=(ac-bd)+(ad+bc)\mathrm{i}$。 设 $$ \begin{cases}\alpha=ac\\ \beta=bd\\\gamma=(a+b)(c+d)\\\end{cases} $$ 则: $$ (a+b\mathrm{i})(c+d\mathrm{i})=(\alpha-\beta)+(\gamma-\alpha-\beta)\mathrm{i} $$ 计算 $\alpha,\beta,\gamma$ 只用了三次乘法,代价就是增加了加法次数。事实上 Gauss 早已发现了三次乘法进行复数相乘的方法,说不定 Strassen 是受到了这个启发? 2. 若矩阵规模不是 $2$ 的幂,如何应用 Strassen 算法? 用值 $0$ 补齐到 $2$ 的幂即可。这是实现代码时需要注意的地方。 3. 已知用 $k$ 次乘法操作完成两个 $3\times 3$ 的矩阵相乘,那么满足在 $o(n^{\log 7})$ 的时间内完成 $n\times n$ 的矩阵相乘,$k$ 的最大值是多少? 容易列出递归式并求解: $$\begin{aligned} T(n)&=kT\left(\dfrac n3\right)+O(n^2)\\ T(n)&=O(\log_3 k)\\ \log_3k&<\log_2 7\\ k_{\max}&=21 \end{aligned}$$ 4. 编写 Strassen 算法的伪代码。 凑合着看吧,这里把 $\rm Partion$ 部分具体写了出来。 其中 $A[1\sim n/2][1\sim n/2]$ 表示由 $\forall i\in[1,n/2],\forall j\in[1,n/2],A_{i,j}$ 组成的子矩阵,其余类似。 $$ \boxed{ \begin{array}{ll} &\textbf{Function Strassen}(A, B)\\ 1& n \gets A.\mathrm{rows}\\ 2& \textbf{if}\ n = 1\\ 3& \qquad \textbf{return}\ a[1, 1]\times b[1, 1]\\ 4& \mathrm{let}\ C\ \mathrm{be\ a\ new}\ n \times n\ \mathrm{matrix}\\ 5& A[1, 1] \gets A[1\sim n / 2][1\sim n / 2]\\ 6& A[1, 2] \gets A[1\sim n / 2][n / 2 + 1\sim n]\\ 7& A[2, 1] \gets A[n / 2 + 1\sim n][1\sim n / 2]\\ 8& A[2, 2] \gets A[n / 2 + 1\sim n][n / 2 + 1\sim n]\\ 9& B[1, 1] \gets B[1\sim n / 2][1\sim n / 2]\\ 10& B[1, 2] \gets B[1\sim n / 2][n / 2 + 1\sim n]\\ 11& B[2, 1] \gets B[n / 2 + 1\sim n][1\sim n / 2]\\ 12& B[2, 2] \gets B[n / 2 + 1\sim n][n / 2 + 1\sim n]\\ 13& S[1] \gets B[1, 2] - B[2, 2]\\ 14& S[2] \gets A[1, 1] + A[1, 2]\\ 15& S[3] \gets A[2, 1] + A[2, 2]\\ 16& S[4] \gets B[2, 1] - B[1, 1]\\ 17& S[5] \gets A[1, 1] + A[2, 2]\\ 18& S[6] \gets B[1, 1] + B[2, 2]\\ 19& S[7] \gets A[1, 2] - A[2, 2]\\ 20& S[8] \gets B[2, 1] + B[2, 2]\\ 21& S[9] \gets A[1, 1] - A[2, 1]\\ 22& S[10] \gets B[1, 1] + B[1, 2]\\ 23& P[1] \gets \mathrm{Strassen}(A[1, 1], S[1])\\ 24& P[2] \gets \mathrm{Strassen}(S[2], B[2, 2])\\ 25& P[3] \gets \mathrm{Strassen}(S[3], B[1, 1])\\ 26& P[4] \gets \mathrm{Strassen}(A[2, 2], S[4])\\ 27& P[5] \gets \mathrm{Strassen}(S[5], S[6])\\ 28& P[6] \gets \mathrm{Strassen}(S[7], S[8])\\ 29& P[7] \gets \mathrm{Strassen}(S[9], S[10])\\ 30& C[1\sim n / 2][1\sim n / 2] \gets P[5] + P[4] - P[2] + P[6]\\ 31& C[1\sim n / 2][n / 2 + 1\sim n] \gets P[1] + P[2]\\ 32& C[n / 2 + 1\sim n][1\sim n / 2] \gets P[3] + P[4]\\ 33& C[n / 2 + 1\sim n][n / 2 + 1\sim n] \gets P[5] + P[1] - P[3] - P[7]\\ 34& \textbf{return}\ C\\ \end{array}} $$ ~~至于 C++ 代码,就留作读者课后习题吧。~~ # Notice 然而很遗憾的是,Strassen 算法由于使用了大量递归,多次创建临时矩阵,常数因子比朴素矩阵乘法大非常多,因此 OI 范围内几乎没有应用价值。朴素的矩阵乘法就足够了。 如果确实想要优化常数,可以从以下方面考虑: - 使用指令集等卡常科技。 - 在子矩阵小到一定规模时使用朴素矩阵乘法。 - 尽可能地减少创建临时矩阵和复制次数。 据 w33z8kqrqk8zzzx33 在[帖子](https://www.luogu.com.cn/discuss/show/321588)里说道,使用指令集可以在 $n=2^{10}$ 时用少于 $0.2$ s 的时间完成(朴素矩阵乘法用指令集应该也会快不少吧?)。 而实际应用中的大型矩阵乘法都依赖于硬件(cache, GPU 等)和分布式计算。 尽管 Strassen 算法看上去没有什么实际用处,《算法导论》依然指出了 Strassen 算法的一个最重要的意义:在理论研究上作出了突破性的贡献。 类比 1959 年发明的 Shell Sort(希尔排序),目前来看也没有任何应用价值,但这是计算机第一次在整数排序上突破了 $\Theta(n^2)$ 的壁障(当时常见的还是插入排序和冒泡排序)。 同样地,是 Strassen 算法使得矩阵乘法在渐进上界上第一次快于 $\Theta(n^3)$,并鼓舞着后人继续在这方面探索。1969 年 Strassen 发表论文的的标题为《高斯消元法并非最优》,也正揭示了该算法的意义所在。 # Extension Strassen 算法其实能推广到很多矩阵操作,在原论文也略有提及。其基本思路都是分治并减少递归次数。本文将对其粗略说明。至于 Strassen 构造的这些玄妙算式是如何想到的,则给后人留下了神秘的美感。 下面所有算法的时间复杂度分析都与前文矩阵乘法相类似,不再赘述。 ## 矩阵求逆 一般的矩阵求逆是用 $\Theta(n^3)$ 的高斯消元法完成的。而 Strassen 在矩阵乘法的基础上,得到了更快的算法。这也是 Strassen 论文标题的由来。 基本思路依然与矩阵乘法类似。 先将矩阵 $A$ 按照左上,左下,右上,右下分解出子矩阵。 创建如下 $7$ 个矩阵 $P_i$: $$ \begin{array}{ll} P_1&=A_{1,1}^{-1}\\ P_2&=A_{2,1}\cdot P_1\\ P_3&=P_{1}\cdot A_{1,2}\\ P_4&=A_{2,1}\cdot P_3\\ P_5&=P_4-A_{2,2}\\ P_6&=P_5^{-1}\\ P_7&=P_{3}\cdot P_6\cdot P_2\\ \end{array} $$ 然后计算出 $A$ 的逆矩阵 $A^{-1}$ 的 $4$ 个子矩阵: $$ \begin{array}{ll} A^{-1}_{1,1}&=P_1-P_7\\ A^{-1}_{1,2}&=P_3\cdot P_6\\ A^{-1}_{2,1}&=P_6\cdot P_2\\ A^{-1}_{2,2}&=-P_6 \end{array} $$ 当子矩阵缩小到一定规模,我们就可以直接用高斯消元法求解来减小常数。 当然,递归求逆过程中也要顺便用 Strassen 算法求矩阵乘法。 需要说明的是,Strassen 假定了操作过程中所有矩阵都是可逆的,而对于更复杂的情况则束手无策。对于 Strassen 算法在矩阵求逆操作上的更深研究,限于篇幅请参考[这里](https://arxiv.org/pdf/1901.00904.pdf)。 ## 解线性方程组 > Similar results hold for solving a system of linear equations or computing a determinant. Strassen 对于矩阵求逆和行列式计算的记录相当简略啊。 我们知道线性方程组可以写成矩阵的形式。 $$ \begin{bmatrix} a_{1,1}&a_{1,2}&a_{1,3}&\cdots& a_{1,m}\\ a_{2,1}&a_{2,2}&a_{2,3}&\cdots& a_{2,m}\\ a_{3,1}&a_{3,2}&a_{3,3}&\cdots& a_{3,m}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{n,1}&a_{n,2}&a_{n,3}&\cdots& a_{n,m}\\ \end{bmatrix} \cdot \begin{bmatrix} x_1\\ x_2\\ x_3\\ \vdots\\ x_n \end{bmatrix}= \begin{bmatrix} y_1\\ y_2\\ y_3\\ \vdots\\ y_n \end{bmatrix} $$ 上式简记为 $A\cdot x=B$,则有 $x=A^{-1}\cdot B$。 如果逆矩阵存在,则方程恰好有一解,直接用 Strassen 算法进行矩阵求逆即可。 ## 行列式 同样把矩阵 $A$ 分解成四个子矩阵,然后用一种简洁的计算方式求解即可: $$ \operatorname{det}(A)=\operatorname{det}(A_{1,1})\cdot \operatorname{det}\left(A_{2,2}-A_{2,1}\cdot A_{1,1}^{-1}\cdot A_{1,2}\right) $$ # Reference - 《算法导论》第三版第四章。 - [Strassen 原论文 - _Gaussian Elimination is not Optimal_](http://scgroup.hpclab.ceid.upatras.gr/class/SC/Papers/Strassen.pdf) 关于矩阵乘法的其它研究,感兴趣的可以阅读以下文献: 1. 用 $O(n^{2.37286})$ 时间完成矩阵乘法 - [Link](https://arxiv.org/pdf/2010.05846.pdf) 2. 对于 $n\times n$ 的矩阵左乘上 $n\times n^p(p\le 0.294)$ 的矩阵,可以做到 $O(n^2)$ 的时间复杂度 - [Link](https://pdf.sciencedirectassets.com/272569/1-s2.0-S0885064X00X00374/1-s2.0-S0885064X97904386/main.pdf?X-Amz-Security-Token=IQoJb3JpZ2luX2VjEJv%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaCXVzLWVhc3QtMSJHMEUCIQD5OIkH%2FHrGP2VSizxsPWtgdganIIa%2FB93f%2FOO%2FooGjiQIgd6aRmE4mA2vIUvv%2B%2BHBxszexkiqMVl7haUxj%2FFlaO1kq%2BgMIUxAEGgwwNTkwMDM1NDY4NjUiDMpNvI3WPqYT0HQv%2FirXAzjMtqcxfTj3gTbdL%2FCPQIrJZZVwkh63%2F1GxZO5oeVdVNZmJfvfVZPprzYFvSlF29uSig4Yle1I%2Fs6yu7c3KHQ3vDHo%2F8NUta0zSHTbY9HfQxz%2Bp2uZiQGYuSZhJplFO%2BYnlyk9Sz6b3AZI1SX%2B8gX2KzSW1zsEfiAjzd5ChktRMovbjuHEZ6otoDq0iuNUkGQemZ5KsTvtZFDUkZizulRBj3OWjSgxu%2Bux3Gh45MXKZ%2B6jghRuEP0shfljhxLgBiDoKzpRcwIQTqVQ0UJxx3neAdo1xcZ2SR%2Bwjdi6%2FoYecK1IDT57huxV1x3gNyYbFLmp3G52KaCNF35uxCLVxZBik%2Fudv31lAjnH%2B3x7IAuiEUUO8ATOYjDualBvswHEPs6e3qs1x5StWlMWDdqm5urBbOtkohIJS0S3t6yvj33tzlOfPrQKxZnVE4XALjNOr6mn%2BDCInKgRutsaGgSwp%2FPUIVL3zWXRYy5WsMAfkG8F5pkmt1zW543EySVIhAV8CbZ9MoAYcaT6TgI00WFIKcVq70Hu4cfZFaaMRnX3190pB1BIBP8xC4zvhdQcKmrmovD84a10g51oBKdbQTZMq8zkKU%2B%2F51B1enL8GLSvutlkLXeBBw9s8mTCH38GMBjqlAc8Y3w9sTgUMZptwWgyXyWkL2Yv38TTAF%2FmSDnYKmcLoUZy4ajgrb52NyaNh60g1BJw1TDuLLA02HoHsoyo%2FE2ydPKtRSgMRDDGU5gyVPpK249T%2FzUkDSXOmgzmIqEwa3FswC1oHY%2BwYyKfiXsU7wHVz1GFbh3yPQ5enu%2FczBSDk2M2uF3qc6Id5kuS03pEiknP3q0aMHUGSy0wt27LyFnHx%2F8KeNg%3D%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20211114T035342Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAQ3PHCVTY6CG4QTXU%2F20211114%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=6ce6eb21e25cc05bc3e1f21aea5ecd4934ebeea6e42214b9449c170727fd5339&hash=2ee762bfa2b44c1868fcb13aeeb0e8891aec38d0c96941b284a0711565b32078&host=68042c943591013ac2b2430a89b270f6af2c76d8dfd086a07176afe7c76c2c61&pii=S0885064X97904386&tid=spdf-6c4f336c-6d87-4148-bd22-1cb4d22fe9fc&sid=cec234191076e046d48945896fd4d48b5d95gxrqa&type=client) 3. 关于 GPU 上进行矩阵乘法的效率 - [Link](http://graphics.stanford.edu/papers/gpumatrixmult/gpumatrixmult.pdf)