Strassen矩阵乘法问题(Java)

举报
WHYBIGDATA 发表于 2023/01/21 18:11:19 2023/01/21
【摘要】 Strassen矩阵乘法问题(Java)

Strassen矩阵乘法问题(Java)



1、前置介绍

矩阵乘法是线性代数中最常见的问题之一 ,它在数值计算中有广泛的应用。 设AB是2个nXn矩阵,
它们的乘积AB同样是一个nXn矩阵。 AB的乘积矩阵C中元素C[i][j]定义为:

C [ i ] [ j ] = k = 1 n A [ i ] [ k ] B [ k ] [ j ] C[i][j] = \sum_{k=1}^{n}A[i][k]B[k][j]

在这里插入图片描述

采用传统方法,时间复杂度为:O(n3)

因为按照上述的定义来计算A和 B的乘积矩阵c,则每计算C的一个元素C[i][j],需要做n次乘法运算和n-1次加法运算。 因此,得到矩阵C的n2 个元素所需的计算时间为 O(n3) 。

为解决计算计算效率问题,Strassen算法由此出现,该算法基本思想是分治,将计算2个n阶矩阵乘积所需的计算时间改进到0(nlog7) = 0(n2.81)

我们知道,C11=A11*B11+A12*B21

在这里插入图片描述

矩阵A和B的示意图如下:

在这里插入图片描述

传统方法:

在这里插入图片描述

2个n阶方阵的乘积转换为8个n/2 阶方阵的乘积和4个n/2阶方阵的加法。

由此可得:

C11 = A11B11 + A12B21

C12 = A11B12 + A12B22

C21 = A21B11 + A22B21

C22 = A21B12 + A22B22

分治法:

为了降低时间复杂度,必须减少乘法的次数。

使用与上例类似的技术,将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵。由此可将方程C=AB重写为:

在这里插入图片描述

2个n阶方阵的乘积转换为7个n/2 阶方阵的乘积和18个n/2阶方阵的加减法。

伪代码如下:

// 递归维度分半算法:
public void STRASSEN(n,A,B,C);
{  
if n=2 then MATRIX-MULTIPLY(ABC)
/ /结束循环,计算 两个2阶方阵的乘法         
else{
  将矩阵AB分块;
  STRASSEN(n/2,A11,B12-B22,M1);
  STRASSEN(n/2,A11+A12,B22,M2); 
  STRASSEN(n/2,A21+A22,B11,M3);
  STRASSEN(n/2,A22,B21-B11,M4);
  STRASSEN(n/2,A11+A22,B11+B22,M5);
  STRASSEN(n/2,A12-A22,B21+B22,M6);
  STRASSEN(n/2,A11-A21,B11+B12,M7);}
}                

算法导论伪代码:

在这里插入图片描述

3、代码实现

public class StrassenMatrixMultiply
{
    public static void main(String[] args)
    {
        int[] a = new int[]
        {
            1, 1, 1, 1,
            2, 2, 2, 2,
            3, 3, 3, 3,
            4, 4, 4, 4
        };

        int[] b = new int[]
        {
            1, 2, 3, 4,
            1, 2, 3, 4,
            1, 2, 3, 4,
            1, 2, 3, 4
        };

        int length = 4;

        int[] c = sMM(a, b, length);

        for(int i = 0; i < c.length; i++)
        {
            System.out.print(c[i] + " ");

            if((i + 1) % length == 0) //换行
                System.out.println();
        }
    }

    public static int[] sMM(int[] a, int[] b, int length) {
        if(length == 2) {
            return getResult(a, b);
        }
        else {
            int tlength = length / 2;
            // 把a数组分为四部分,进行分治递归
            int[] aa = new int[tlength * tlength];
            int[] ab = new int[tlength * tlength];
            int[] ac = new int[tlength * tlength];
            int[] ad = new int[tlength * tlength];
            // 把b数组分为四部分,进行分治递归
            int[] ba = new int[tlength * tlength];
            int[] bb = new int[tlength * tlength];
            int[] bc = new int[tlength * tlength];
            int[] bd = new int[tlength * tlength];

            // TODO 划分子矩阵
            for(int i = 0; i < length; i++) {
                for(int j = 0; j < length; j++) {
                    /*
                     * 划分矩阵:
                     * 例子:将 4 * 4 的矩阵,变为 2 * 2 的矩阵,
                     * 那么原矩阵左上、右上、左下、右下的四个元素分别归为新矩阵
                    */
                    if(i < tlength) {
                        if(j < tlength) {
                            aa[i * tlength + j] = a[i * length + j];
                            ba[i * tlength + j] = b[i * length + j];
                        } else {
                            ab[i * tlength + (j - tlength)] = a[i * length + j];
                            bb[i * tlength + (j - tlength)] = b[i * length + j];
                        }
                    } else {
                        if(j < tlength) {
                            //i 大于 tlength 时,需要减去 tlength,j同理
                            //因为 b,c,d三个子矩阵有对应了父矩阵的后半部分
                            ac[(i - tlength) * tlength + j] = a[i * length + j];
                            bc[(i - tlength) * tlength + j] = b[i * length + j];
                        } else {
                            ad[(i - tlength) * tlength + (j - tlength)] = a[i * length + j];
                            bd[(i - tlength) * tlength + (j - tlength)] = b[i * length + j];
                        }
                    }
                }
            }

            // TODO 分治递归
            int[] result = new int[length * length];

            // temp:4个临时矩阵
            int[] t1 = add(sMM(aa, ba, tlength), sMM(ab, bc, tlength));
            int[] t2 = add(sMM(aa, bb, tlength), sMM(ab, bd, tlength));
            int[] t3 = add(sMM(ac, ba, tlength), sMM(ad, bc, tlength));
            int[] t4 = add(sMM(ac, bb, tlength), sMM(ad, bd, tlength));

            // TODO 归并结果
            for(int i = 0; i < length; i++) {
                for(int j = 0; j < length; j++) {
                    if (i < tlength){
                        if(j < tlength) {
                            result[i * length + j] = t1[i * tlength + j];
                        } else {
                            result[i * length + j] = t2[i * tlength + (j - tlength)];
                        }
                    } else {
                        if(j < tlength) {
                            result[i * length + j] = t3[(i - tlength) * tlength + j];
                        } else {
                            result[i * length + j] = t4[(i - tlength) * tlength + (j - tlength)];
                        }
                    }
                }
            }
            return result;
        }
    }

    public static int[] getResult(int[] a, int[] b) {
        int p1 = a[0] * (b[1] - b[3]);
        int p2 = (a[0] + a[1]) * b[3];
        int p3 = (a[2] + a[3]) * b[0];
        int p4 = a[3] * (b[2] - b[0]);
        int p5 = (a[0] + a[3]) * (b[0] + b[3]);
        int p6 = (a[1] - a[3]) * (b[2] + b[3]);
        int p7 = (a[0] - a[2]) * (b[0] + b[1]);

        int c00 = p5 + p4 - p2 + p6;
        int c01 = p1 + p2;
        int c10 = p3 + p4;
        int c11 = p5 + p1 -p3 - p7;

        return new int[] {c00, c01, c10, c11};
    }

    public static int[] add(int[] a, int[] b) {
        int[] c = new int[a.length];
        for(int i = 0; i < a.length; i++) {
            c[i] = a[i] + b[i];
	    }
        return c;
    }

    // TODO 返回一个数是不是2的幂次方
    public static boolean adjust(int x) {
        return (x & (x - 1)) == 0;
    }
}

4、复杂度分析

传统方法和分治法的复杂度比较,如下图所示;

在这里插入图片描述

T ( n ) = { O ( 1 ) , n = 2 7 T ( n / 2 ) + O ( n 2 ) , n > 2 T(n) = \left\{ \begin{matrix} O(1), n = 2 \\ 7T(n/2) + O(n^2), n > 2\\ \end{matrix} \right.

T(n) = 0(nlog7 ) = 0(n2.81)

5、参考资料

  • 算法分析与设计(第四版)
  • 算法导论第三版
  • 博客园

结束!

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。