矩阵乘法: 从 O(n^3) 到 O(n^2.78)
wind_boy
·
·
个人记录
翻译自 矩阵乘法: 从 Strassen 到 Coppersmith–Winograd - EI 的前 1/3,并进行了一些扩充与人性化.
由于时间原因,本文的终点是证明 \omega\leq 2.7799.
张量描述
固定域 F,我们称 F 上一个双线性计算问题是对于
f_k=\sum_{i=1}^N\sum_{j=1}^Mt_{i,j,k}x_iy_j\pod{1\leq k\leq K}
输入所有 x_i,y_j,计算所有 f_k. 这里 t\in F^{N\times M\times K} 是一个三阶张量,且 t 为一个常量.
我们记 C(t) 为一个双线性计算问题所需要的最少乘法次数.
例如,对于两个 K\times M 和 M\times N 的矩阵乘法 t\in F^{KM\times MN\times KN},有
t_{(i,j'),(j,k'),(i',k)}=\begin{cases}
1&\text{if}\ i=i',j=j',k=k'\\
0&\text{otherwise}.
\end{cases}
我们以后记这个张量为 \langle K,M,N\rangle.
此外,多项式乘法也可以被描述为一个双线性计算问题. 这里不再阐述.
设有三线性多项式
\sum_{i=1}^N\sum_{j=1}^M\sum_{k=1}^Kt_{i,j,k}x_iy_jz_k
三线性多项式与双线性计算问题一一对应,且三线性多项式中 z_k 的系数即双线性计算问题中的 f_k.
容易发现矩阵乘法的三线性多项式为
\sum_{i=1}^n\sum_{j=1}^m\sum_{k=1}^lx_{(i,j)}y_{(j,k)}z_{(i,k)}
表示为图像的形式即
对于一个双线性计算问题,我们现在需要做的事情就是,构造 r 个如下的乘积式
P_{\lambda}=\left(\sum_{i}u_{\lambda,i}x_i\right)\left(\sum_{i}v_{\lambda,i}y_i\right)\pod{1\leq\lambda\leq r}
使得每个 f_k 都可以被 P 线性表示,即存在 w_{\lambda}\in F^r,满足 f_k=\sum_{\lambda=1}^rw_{\lambda,k}P_{\lambda}.
换句话说,我们要找到矩阵 u,v,w,满足
\sum_{i,j,k}t_{i,j,k}x_iy_jz_k=\sum_{\lambda=1}^r(\sum_{i=1}^Nu_{\lambda,i}x_i)(\sum_{j=1}^Mv_{\lambda,j}y_j)(\sum_{k=1}^Kw_{\lambda,k}z_k)
即 t=\sum_{\lambda=1}^ru_{\lambda}\otimes v_{\lambda}\otimes w_{\lambda}.
这样,x,y,z 就有了相等的地位,就是说输入和输出是等价的了.
我们记 R(t) 为满足上述条件的最小的 r,即这个张量的秩. 其相当于和为 t 的秩为 1 的张量的最小个数.
另外,可以证明 C(t)\leq R(t)\leq 2C(t),所以研究 R(t) 很大程度上就在研究矩阵乘法的效率.
对于两个张量 t\in F^{N\times M\times K},t'\in F^{N'\times M'\times K'},定义其张量和 t\oplus t'\in F^{(N+N')\times (M+M')\times(K+K')} 为
(t\oplus t')_{i,j,k}=\begin{cases}
t_{i,j,k}&i\leq N,j\leq M,k\leq K\\
t'_{i-N,j-M,k-K}&i>N,j>M,k>K\\
0&o.w.
\end{cases}
形象地说,就是把 t 放到 t\oplus t' 的左下角,把 t' 放到 t\oplus t' 的右上角.
同样,定义其张量积 t\otimes t'\in F^{NN'\times MM'\times KK'} 为
(t\otimes t')_{ii',jj',kk'}=t_{i,j,k}t'_{i',j',k'}
张量满足以下两个性质:
-
-
证明:设 $t=\sum_{l=1}^ra_l\otimes b_l\otimes c_l,t'=\sum_{l=1}^sa'_l\otimes b'_l\otimes c'_l$.
设 $\hat{a}_{l,l'}=a_l\otimes a'_{l'}$,$\hat{b},\hat{c}$ 同理,则容易验证 $t\otimes t'=\sum_{l=1}^r\sum_{l’=1}^s\hat{a}_{l,l’}\otimes\hat{b}_{l,l’}\otimes\hat{c}_{l,l’}$.
对于矩阵乘法对应的张量 \langle K,M,N\rangle,显然其与 \langle M,N,K\rangle 同构,因为其每个非 0 位置都可以看做一个三角形. 因此有 R(\langle K,M,N\rangle)=R(\langle M,N,K\rangle)=R(\langle N,K,M\rangle).
另外,显然 \langle K,M,N\rangle\otimes\langle K',M',N'\rangle=\langle KK',MM',NN'\rangle,其代表对矩阵进行 K\times M\times N 的分块后再进行 K'\times M'\times N' 的分块.
接着,我们记
\omega=\inf_{n>1}\frac{\log R(\langle n,n,n\rangle)}{\log n}
根据主定理,若我们每次进行 n\times n\times n 的分块,则可以做到 O(n^{\frac{\log R(\langle n,n,n\rangle)}{\log n}+\epsilon}) 的时间复杂度. 因此矩阵乘法存在 O(n^{\omega+\epsilon}) 的时间复杂度.
那如果我们每次进行 K\times M\times N 的分块呢?
事实上,若 R(\langle K,M,N\rangle)\leq r,则 \omega\leq\dfrac{3\log r}{\log(KMN)}.
证明:我们先将其规划成方阵. 设 T=KMN,根据前文的观察有
\langle T,T,T\rangle=\langle K,M,N\rangle\otimes\langle M,N,K\rangle\otimes\langle N,K,M\rangle
因此根据张量的性质,有
R(\langle T,T,T\rangle)\leq R(\langle K,M,N\rangle)\cdot R(\langle M,N,K\rangle)\cdot R(\langle N,K,M\rangle)\leq r^3
故 \omega\leq\dfrac{\log R(\langle T,T,T\rangle)}{\log T}\leq\dfrac{\log r^3}{\log T}=\dfrac{3\log r}{\log(KMN)}.
Strassen 和 Pan 的构造
最初,Strassen 的构造无非指出 R(\langle2,2,2\rangle)\leq 7,因此 \omega\leq \log_27\approx2.8074.
然后,Pan (1980) 构造出了 R(\langle70,70,70\rangle)\leq 143640,这导出了 \omega\leq2.796.
到目前为止,远超我们理解的奇迹还没有出现.
Border rank
事情的转机来源于一个奇怪的现象. 假设我们有一系列连续的矩阵 A_j,其满足 j\to \infty,A_j\to A. 假设对于任意 j,都有 R(A_j)\leq r,则可以证明 R(A)\leq r.
为什么这样是对的呢?由于矩阵是连续的,因此每个 r+1 阶子矩阵的行列式也是连续的,故由于 j\to\infty 时,A_j 的每个 r+1 阶子矩阵的行列式都趋近于 0,因此 A 的每个 r+1 阶子矩阵的行列式为 0.
但奇怪的是,这对于三阶张量是不成立的. 例如我们现在有一个三线性多项式
a_0b_0c_0+a_1b_0c_1+a_0b_1c_1
设 t 为其代表的张量. 显然 R(t)=3.
同样地,我们让一系列三阶张量趋近于它. 具体地,设
t(\epsilon)=(1,\epsilon)\otimes(1,\epsilon)\otimes (0,1/\epsilon)+(1,0)\otimes(1,0)\otimes (1,-1/\epsilon)
容易验证当 \epsilon\to 0 时,t(\epsilon)\to t. 但奇怪的是,从定义中可以看到,R(t(\epsilon))=2.
我们考虑多项式环 F[\epsilon],其包含所有形如 \sum_{i=0}^ma_i\epsilon^i 的多项式,其中 m\geq 0,a_i\in F\pod{1\leq i\leq m}.
我们定义 R_h(t) 为最小的 r,使得有一组解 u_{\lambda}\in F[\epsilon]^K,v_{\lambda}\in F[\epsilon]^M,w_{\lambda}\in F[\epsilon]^N\pod{1\leq \lambda\leq r},满足
\sum_{\lambda=1}^ru_{\lambda}\otimes v_{\lambda}\otimes w_{\lambda}=\epsilon^ht+O(\epsilon^{h+1})
其中 t\in F^{K\times M\times N},O(\epsilon^{h+1}) 为满足 \epsilon 的指数均大于 h 的张量.
也就是说,可以认为,当 \epsilon\to 0 时,\frac1{\epsilon^h}\sum_{\lambda=1}^ru_{\lambda}\otimes v_{\lambda}\otimes w_{\lambda}\to t.
定义 border rank \underline{R}(t)=\min_{h}R_h(t).
显然,R(t)=R_0(t)\geq R_1(t)\geq\cdots\geq\underline{R}(t).
接着,不难证明:
-
-
进一步地,若 R_h(t)\leq r,则 R(t)\leq \binom{h+2}2\cdot r.
证明:设 \sum_{l=1}^ru_l\otimes v_l\otimes w_l=\epsilon^h\cdot t+O(\epsilon^{h+1}),由于其每个值都是关于 \epsilon 的 h 次多项式,因此设 u_{l}=\sum_{i=0}^h u'_{l,i}\epsilon^i,v,w 同理,则有 t=\sum_{i+j+k=h}\sum_{l=0}^ru'_{l,i}\otimes v'_{l,j}\otimes w'_{l,k}. 根据插板法,容易得到满足 i+j+k=h 的三元组个数为 \binom{h+2}2. 故 R(t)\leq\binom{h+2}2\cdot r.
因此可以导出如下结果:
\omega\leq\frac{3\log\underline{R}(\langle K,M,N\rangle)}{\log(KMN)}
证明:和普通 rank 一样,我们先将其规划成方阵. 设 T=KMN,根据矩阵张量的相关性质有
\langle T,T,T\rangle=\langle K,M,N\rangle\otimes\langle M,N,K\rangle\otimes\langle N,K,M\rangle
设 h 满足 R_h(\langle K,M,N\rangle)=\underline{R}(\langle K,M,N\rangle). 根据 border rank 的性质,有
R_{3h}(\langle T,T,T\rangle)\leq r^3
将其张量 s 次,得到
R_{3hs}(\langle T,T,T\rangle)\leq r^{3s}
根据之前的结论,有 R(\langle T,T,T\rangle)\leq c_{3hs}\cdot r^{3s},其中 c_{3hs}=\binom{3hs+2}2. 因此
\omega\leq\frac{\log(c_{3hs}r^{3s})}{\log(T^s)}=\frac{3\log r}{\log(KMN)}+O(\frac{\log s}s)
因此当 s\to \infty 时,其趋近于 \frac{3\log r}{\log(KMN)}.
因此,我们不需要控制 rank,只需控制 border rank(其实也是间接控制 rank),就能得到更快的矩阵乘法了.
Bini, Capovani, Romani (1979) 得到了 \underline{R}(\langle 2,2,3\rangle)\leq 10,进而得到 \omega\leq2.7799.
具体地,我们考虑这样一个矩阵乘法:
\begin{bmatrix}
x_{11}&x_{12}\\x_{21}&x_{22}
\end{bmatrix}
\begin{bmatrix}
y_{11}&y_{12}\\y_{21}&y_{22}
\end{bmatrix}=
\begin{bmatrix}
z_{11}&z_{12}\\z_{21}&\cancel{z_{22}}
\end{bmatrix}
其中,z_{22} 无需计算. 设其张量为 t. 显然将 t 复制两份就是 \langle2,2,3\rangle.
可以证明,R(t)=6. 但是,我们可以构造出 R_1(t)\leq 5. 设
\begin{aligned}
P_1&=(x_{12}+\epsilon x_{22})y_{22}\\
P_2&=x_{11}(y_{11}+\epsilon y_{12})\\
P_3&=x_{12}(y_{12}+y_{21}+\epsilon y_{22})\\
P_4&=(x_{11}+x_{12}+\epsilon x_{21})y_{11}\\
P_5&=(x_{12}+\epsilon x_{21})(y_{11}+\epsilon y_{22})
\end{aligned}
容易验证
\begin{aligned}
\epsilon P_1+\epsilon P_2&=\epsilon z_{11}+O(\epsilon^2)\\
P_2-P_4+P_5&=\epsilon z_{12}+O(\epsilon^2)\\
P_1-P_3+P_5&=\epsilon z_{21}+O(\epsilon^2)\\
\end{aligned}
故 \underline{R}(t)\leq 5. 因此 \underline{R}(\langle2,2,3\rangle)\leq 2\underline{R}(t)\leq 10. 因此 \omega\leq\frac{3\log 10}{\log 12}\leq2.7799.
更多?
本文其实还没有介绍 Schönhage 的 \omega\leq 2.522,也没有介绍 CW 的 \omega\leq2.404 或 \omega\leq 2.3755,更没有介绍目前的最优结果 laser 法. 有兴趣的同学可以看 EI 的文章(当然直接看原论文可能更好理解). 不过建议大家不要像我这样浪费时间学些没用的东西.