加载中…
个人资料
  • 博客等级:
  • 博客积分:
  • 博客访问:
  • 关注人气:
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
正文 字体大小:

矩阵乘法的Strassen算法

(2019-11-25 21:33:28)
分类: 生活

目录

一、矩阵相乘

二、矩阵相乘的朴素算法

三、Strassen其人及其算法

四、 Strassen的理想

五、 Strassen的一种代码实现

六、 参考

一、矩阵相乘

矩阵相乘,是指一个m*n的矩阵和一个n*k的矩阵相乘而得到一个m*k矩阵的一种运算。矩阵相乘可用于线性变换和矩阵分解[1]

二、矩阵相乘的朴素算法

矩阵相乘的一般步骤是第一个矩阵的第i*第二个矩阵的第j列,得到第三个矩阵的第i行第j列元素

矩阵乘法的Strassen算法

因此其素朴算法是用三层循环的出计算结果

 for i 1 to n

     do for j 1 to n

         do c[i][j] 0

             for k 1 to n

                 do c[i][j] c[i][j] + a[i][k] b[k][j]

显然其时间复杂度为O(n^3)

三、Strassen其人及其算法

矩阵中的元素排列整齐,很容易分块,因此采用分治思想,将其划分为块,对各块进行相乘,分而治之,这样就会降低乘法运算的规模。

我们可以把一个n*n的矩阵划分为4n/2*n/2的子矩阵进行运算。

矩阵乘法的Strassen算法

这样递归进行预算,其形式如下,其中A11-A22代表a-d,B11-B22代表e-h:[5]

矩阵乘法的Strassen算法

由于划分的规模缩小为n/2,总共划分成了8块,且各块相加的时间复杂度总和为O(n^2),故其时间复杂度可以近似表示为T(n)=8T(n/2)+O(n^2)

根据用主方法(the master method)求解递归式的方法O(n^2),因此此分治方法的时间复杂度仍为O(n^3)

Volker Strassen是一位出生于1936年的德国数学家,他因为在概率论上的工作而广为人知,但是在计算机科学和算法领域,他却因为矩阵相乘算法而被大部分人认识,这个算法目前仍然是比通用矩阵相乘算法性能好的主要算法之一。[2]

Strassen1969年第一次发表关于这个算法的文章,并证明了复杂度为n^3的算法并不是最优算法。[2]

他做了一个巧妙的组合计算:

先计算如何组合

矩阵乘法的Strassen算法

再计算如下

矩阵乘法的Strassen算法

最后得出

矩阵乘法的Strassen算法

这时中间的分治计算只需P1-P7这七个乘法,因此时间复杂度公式也随之而变为:

T(n)=7T(n/2)+O(n^2)

同样据用主方法求解递归式的方法O(n^2)<</span>O(n^2.807),因此此分治方法的时间复杂度约为O(n^2.807)

虽然,Strassen给出的解决方案只一点点,但是,他的贡献却是相当巨大的,就是因为这导致了矩阵相乘领域更多的研究,产生了更快的算法,比如复杂度为O(n^2.3737)Coppersmith-Winograd算法。

四、Strassen的理想

对于研究时间复杂度理论而言,Strassen算法贡献巨大,这个算法鼓励我们要朝着完美一步步接近,这是非常值得肯定的,理想一定要有。

然而,对于Strassen算法的实际测试我们也要进行关注,压力实测来源于网络:[3]

矩阵乘法的Strassen算法

数据取600位上界,即超过10分钟跳出。可以看到使用Strassen算法时,耗时不但没有减少,反而剧烈增多,在n=700时计算时间就无法忍受。

造成如此结果的原因根据网上查阅资料,现罗列如下:[4]

  1)采用Strassen算法作递归运算,需要创建大量的动态二维数组,其中分配堆内存空间将占用大量计算时间,从而掩盖了Strassen算法的优势

  2)于是对Strassen算法做出改进,设定一个界限。当n<</font>界限时,使用普通法计算矩阵,而不继续分治递归。需要合理设置界限,不同环境(硬件配置)下界限不同

  3)矩阵乘法一般意义上还是选择的是朴素的方法,只有当矩阵变稠密,而且矩阵的阶数很大时,才会考虑使用Strassen算法。

改进策略为:设定一个界限。当n<界限时,使用普通法计算矩阵,而不继续分治递归。

改进后算法优势明显,就算时间大幅下降。之后,针对不同大小的界限进行试验。在初步试验中发现,当数据规模小于1000时,下界S法的差别不大,规模大于1000以后,n取值越大,消耗时间下降。最优的界限值在32128之间。

五、Strassen的一种代码实现[6]

#include  

#define  N  4  

//matrix + matrix  

void plus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  {  

   int i, j;  

   for( i = 0; i < N / 2; i++ )  

   {  

       for( j = 0; j < N / 2; j++ )  

       {  

           t[i][j] = r[i][j] + s[i][j];  

       }  

   }  

}  

//matrix - matrix  

void minus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  {  

   int i, j;  

   for( i = 0; i < N / 2; i++ )  

   {  

       for( j = 0; j < N / 2; j++ )  

       {  

           t[i][j] = r[i][j] - s[i][j];  

       }  

   }  

}  

//matrix * matrix  

void mul( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] ) {  

   int i, j, k;  

   for( i = 0; i < N / 2; i++ )  

   {  

       for( j = 0; j < N / 2; j++ )  

       {  

           t[i][j] = 0;  

           for( k = 0; k < N / 2; k++ )  

           {  

               t[i][j] += r[i][k] * s[k][j];  

           }  

       }  

   }  

}  

int main() {  

   int i, j, k;  

   int mat[N][N];  

   int m1[N][N];  

   int m2[N][N];  

   int a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];  

   int e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];  

   int p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];  

   int p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];  

   int r[N/2][N/2], s[N/2][N/2], t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2], t2[N/2][N/2];  

 

 

   printf("\nInput the first matrix...:\n");  

   for( i = 0; i < N; i++ )  

   {  

       for( j = 0; j < N; j++ )  

       {  

           scanf("%d", &m1[i][j]);  

       }  

   }  

 

   printf("\nInput the second matrix...:\n");  

   for( i = 0; i < N; i++ )  

   {  

       for( j = 0; j < N; j++ )  

       {  

           scanf("%d", &m2[i][j]);  

       }  

   }  

 

   // a b c d e f g h  

   for( i = 0; i < N / 2; i++ )  

   {  

       for( j = 0; j < N / 2; j++ )  

       {  

           a[i][j] = m1[i][j];  

           b[i][j] = m1[i][j + N / 2];  

           c[i][j] = m1[i + N / 2][j];  

           d[i][j] = m1[i + N / 2][j + N / 2];  

           e[i][j] = m2[i][j];  

           f[i][j] = m2[i][j + N / 2];  

           g[i][j] = m2[i + N / 2][j];  

           h[i][j] = m2[i + N / 2][j + N / 2];  

       }  

   }  

     

   //p1  

   minus( r, f, h );  

   mul( p1, a, r );   

 

   //p2  

   plus( r, a, b );  

   mul( p2, r, h );  

 

   //p3  

   plus( r, c, d );  

   mul( p3, r, e );  

 

   //p4  

   minus( r, g, e );  

   mul( p4, d, r );  

 

   //p5  

   plus( r, a, d );  

   plus( s, e, f );  

   mul( p5, r, s );  

 

   //p6  

   minus( r, b, d );  

   plus( s, g, h );  

   mul( p6, r, s );  

 

   //p7  

   minus( r, a, c );  

   plus( s, e, f );  

   mul( p7, r, s );  

 

   //r = p5 + p4 - p2 + p6  

   plus( t1, p5, p4 );  

   minus( t2, t1, p2 );  

   plus( r, t2, p6 );  

 

   //s = p1 + p2  

   plus( s, p1, p2 );  

 

   //t = p3 + p4  

   plus( t, p3, p4 );  

     

   //u = p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 )  

   plus( t1, p5, p1 );  

   plus( t2, p3, p7 );  

   minus( u, t1, t2 );  

 

   for( i = 0; i < N / 2; i++ )  

   {  

       for( j = 0; j < N / 2; j++ )  

       {  

           mat[i][j] = r[i][j];  

           mat[i][j + N / 2] = s[i][j];  

           mat[i + N / 2][j] = t[i][j];  

           mat[i + N / 2][j + N / 2] = u[i][j];  

       }  

   }  

 

   printf("\n下面是strassen算法处理结果:\n");  

   for( i = 0; i < N; i++ )  

   {  

       for( j = 0; j < N; j++ )  

       {  

           printf("%d ", mat[i][j]);  

       }  

       printf("\n");  

   }  

 

   //下面是朴素算法处理  

   printf("\n下面是朴素算法处理结果:\n");  

   for( i = 0; i < N; i++ )  

   {  

       for( j = 0; j < N; j++ )  

       {  

           mat[i][j] = 0;  

           for( k = 0; k < N; k++ )  

           {  

               mat[i][j] += m1[i][j] * m2[i][j];  

           }  

       }  

   }  

 

   for( i = 0; i < N; i++ )  

   {  

       for( j = 0; j < N; j++ )  

       {  

           printf("%d ", mat[i][j]);  

       }  

       printf("\n");  

   }   

   return 0;  

}  

六、参考

1、https://www.zhihu.com/question/21351965?sort=created

2、https://yq.aliyun.com/articles/3591

3、https://blog.csdn.net/handawnc/article/details/7987107

4、https://www.cnblogs.com/zhoutaotao/p/3963048.html

5、《算法导论》第三版

6、https://www.2cto.com/kf/201303/197291.html

0

阅读 收藏 喜欢 打印举报/Report
  

新浪BLOG意见反馈留言板 欢迎批评指正

新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 产品答疑

新浪公司 版权所有