目录
一、矩阵相乘
二、矩阵相乘的朴素算法
三、Strassen其人及其算法
四、
Strassen的理想
五、
附Strassen的一种代码实现
六、
参考
一、矩阵相乘
矩阵相乘,是指一个m*n的矩阵和一个n*k的矩阵相乘而得到一个m*k矩阵的一种运算。矩阵相乘可用于线性变换和矩阵分解[1]。
二、矩阵相乘的朴素算法
矩阵相乘的一般步骤是第一个矩阵的第i行*第二个矩阵的第j列,得到第三个矩阵的第i行第j列元素

因此其素朴算法是用三层循环的出计算结果
for i ← 1 to n
do
for j ← 1 to
n
do
c[i][j] ← 0
for
k ← 1 to
n
do
c[i][j] ← c[i][j] +
a[i][k]⋅ b[k][j]
显然其时间复杂度为O(n^3)
三、Strassen其人及其算法
矩阵中的元素排列整齐,很容易分块,因此采用分治思想,将其划分为块,对各块进行相乘,分而治之,这样就会降低乘法运算的规模。
我们可以把一个n*n的矩阵划分为4个n/2*n/2的子矩阵进行运算。

这样递归进行预算,其形式如下,其中A11-A22代表a-d,B11-B22代表e-h:[5]

由于划分的规模缩小为n/2,总共划分成了8块,且各块相加的时间复杂度总和为O(n^2),故其时间复杂度可以近似表示为T(n)=8T(n/2)+O(n^2)
根据用主方法(the master method)求解递归式的方法O(n^2),因此此分治方法的时间复杂度仍为O(n^3)。
Volker
Strassen是一位出生于1936年的德国数学家,他因为在概率论上的工作而广为人知,但是在计算机科学和算法领域,他却因为矩阵相乘算法而被大部分人认识,这个算法目前仍然是比通用矩阵相乘算法性能好的主要算法之一。[2]
Strassen在1969年第一次发表关于这个算法的文章,并证明了复杂度为n^3的算法并不是最优算法。[2]
他做了一个巧妙的组合计算:
先计算如何组合

再计算如下

最后得出

这时中间的分治计算只需P1-P7这七个乘法,因此时间复杂度公式也随之而变为:
T(n)=7T(n/2)+O(n^2)
同样据用主方法求解递归式的方法O(n^2)<</span>O(n^2.807),因此此分治方法的时间复杂度约为O(n^2.807)。
虽然,Strassen给出的解决方案只是好了一点点,但是,他的贡献却是相当巨大的,就是因为这导致了矩阵相乘领域更多的研究,产生了更快的算法,比如复杂度为O(n^2.3737)的Coppersmith-Winograd算法。
四、Strassen的理想
对于研究时间复杂度理论而言,Strassen算法贡献巨大,这个算法鼓励我们要朝着完美一步步接近,这是非常值得肯定的,理想一定要有。
然而,对于Strassen算法的实际测试我们也要进行关注,压力实测来源于网络:[3]

数据取600位上界,即超过10分钟跳出。可以看到使用Strassen算法时,耗时不但没有减少,反而剧烈增多,在n=700时计算时间就无法忍受。
造成如此结果的原因根据网上查阅资料,现罗列如下:[4]
1)采用Strassen算法作递归运算,需要创建大量的动态二维数组,其中分配堆内存空间将占用大量计算时间,从而掩盖了Strassen算法的优势
2)于是对Strassen算法做出改进,设定一个界限。当n<</font>界限时,使用普通法计算矩阵,而不继续分治递归。需要合理设置界限,不同环境(硬件配置)下界限不同
3)矩阵乘法一般意义上还是选择的是朴素的方法,只有当矩阵变稠密,而且矩阵的阶数很大时,才会考虑使用Strassen算法。
改进策略为:设定一个界限。当n<界限时,使用普通法计算矩阵,而不继续分治递归。
改进后算法优势明显,就算时间大幅下降。之后,针对不同大小的界限进行试验。在初步试验中发现,当数据规模小于1000时,下界S法的差别不大,规模大于1000以后,n取值越大,消耗时间下降。最优的界限值在32~128之间。
五、附Strassen的一种代码实现[6]
#include
#define
N 4
//matrix + matrix
void
plus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )
{
int
i, j;
for(
i = 0; i < N / 2; i++ )
{
for(
j = 0; j < N / 2; j++ )
{
t[i][j]
= r[i][j] + s[i][j];
}
}
}
//matrix - matrix
void
minus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )
{
int
i, j;
for(
i = 0; i < N / 2; i++ )
{
for(
j = 0; j < N / 2; j++ )
{
t[i][j]
= r[i][j] - s[i][j];
}
}
}
//matrix * matrix
void
mul( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] ) {
int
i, j, k;
for(
i = 0; i < N / 2; i++ )
{
for(
j = 0; j < N / 2; j++ )
{
t[i][j]
= 0;
for(
k = 0; k < N / 2; k++ )
{
t[i][j]
+= r[i][k] * s[k][j];
}
}
}
}
int
main() {
int
i, j, k;
int
mat[N][N];
int
m1[N][N];
int
m2[N][N];
int
a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];
int
e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];
int
p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];
int
p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];
int
r[N/2][N/2], s[N/2][N/2], t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2],
t2[N/2][N/2];
printf("\nInput
the first matrix...:\n");
for(
i = 0; i < N; i++ )
{
for(
j = 0; j < N; j++ )
{
scanf("%d",
&m1[i][j]);
}
}
printf("\nInput
the second matrix...:\n");
for(
i = 0; i < N; i++ )
{
for(
j = 0; j < N; j++ )
{
scanf("%d",
&m2[i][j]);
}
}
//
a b c d e f g h
for(
i = 0; i < N / 2; i++ )
{
for(
j = 0; j < N / 2; j++ )
{
a[i][j]
= m1[i][j];
b[i][j]
= m1[i][j + N / 2];
c[i][j]
= m1[i + N / 2][j];
d[i][j]
= m1[i + N / 2][j + N / 2];
e[i][j]
= m2[i][j];
f[i][j]
= m2[i][j + N / 2];
g[i][j]
= m2[i + N / 2][j];
h[i][j]
= m2[i + N / 2][j + N / 2];
}
}
//p1
minus(
r, f, h );
mul(
p1, a, r );
//p2
plus(
r, a, b );
mul(
p2, r, h );
//p3
plus(
r, c, d );
mul(
p3, r, e );
//p4
minus(
r, g, e );
mul(
p4, d, r );
//p5
plus(
r, a, d );
plus(
s, e, f );
mul(
p5, r, s );
//p6
minus(
r, b, d );
plus(
s, g, h );
mul(
p6, r, s );
//p7
minus(
r, a, c );
plus(
s, e, f );
mul(
p7, r, s );
//r
= p5 + p4 - p2 + p6
plus(
t1, p5, p4 );
minus(
t2, t1, p2 );
plus(
r, t2, p6 );
//s
= p1 + p2
plus(
s, p1, p2 );
//t
= p3 + p4
plus(
t, p3, p4 );
//u
= p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 )
plus(
t1, p5, p1 );
plus(
t2, p3, p7 );
minus(
u, t1, t2 );
for(
i = 0; i < N / 2; i++ )
{
for(
j = 0; j < N / 2; j++ )
{
mat[i][j]
= r[i][j];
mat[i][j
+ N / 2] = s[i][j];
mat[i
+ N / 2][j] = t[i][j];
mat[i
+ N / 2][j + N / 2] = u[i][j];
}
}
printf("\n下面是strassen算法处理结果:\n");
for(
i = 0; i < N; i++ )
{
for(
j = 0; j < N; j++ )
{
printf("%d
", mat[i][j]);
}
printf("\n");
}
//下面是朴素算法处理
printf("\n下面是朴素算法处理结果:\n");
for(
i = 0; i < N; i++ )
{
for(
j = 0; j < N; j++ )
{
mat[i][j]
= 0;
for(
k = 0; k < N; k++ )
{
mat[i][j]
+= m1[i][j] * m2[i][j];
}
}
}
for(
i = 0; i < N; i++ )
{
for(
j = 0; j < N; j++ )
{
printf("%d
", mat[i][j]);
}
printf("\n");
}
return
0;
}
六、参考
1、https://www.zhihu.com/question/21351965?sort=created
2、https://yq.aliyun.com/articles/3591
3、https://blog.csdn.net/handawnc/article/details/7987107
4、https://www.cnblogs.com/zhoutaotao/p/3963048.html
5、《算法导论》第三版
6、https://www.2cto.com/kf/201303/197291.html
加载中,请稍候......