
As part of a program, I need to multiply two 2D matrices together. These matrices are instances of a Matrix class that I created. The code I have at the moment works well, but I was wondering if there is a more efficient way of multiplying them.

public Matrix multiply(Matrix matrix) {
    //returns 2D array of Matrix matrix object
    int[][] userMatrix = matrix.getMatrix();
    //int [][] for the multiplied matrix
    int[][] multiplied = new int[length][length];
    int[] tempA = new int[length];
    int[] tempB = new int[length];

    int sum = 0;
    for (int row = 0; row < length; row++) {
        for (int col = 0; col < length; col++) {
            tempA[col] = arrayObject[row][col];
        }
        for (int j = 0; j < length; j++) {
            for (int i = 0; i < length; i++) {
                tempB[i] = userMatrix[i][j];
            }
            for (int k = 0; k < length; k++) {
                sum += tempA[k] * tempB[k];
            }
            multiplied[row][j] = sum;
            sum = 0;
        }
    }

    //converts the int[][] to a Matrix object
    Matrix returnMatrix = new Matrix(multiplied, multiplied.length);

    return returnMatrix;
}
Trisha T

4 Answers


Use a Strassen-like scheme for the multiplication. In essence, you break your matrices into four submatrices and calculate some intermediate values, then calculate the solution from these smaller intermediate values.

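Schema (matrix split): each of A, B, and the product C is divided into four equally sized blocks, so A consists of A_11, A_12, A_21, A_22, and likewise for B and C.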

Now, instead of calculating

C_11=A_11·B_11+A_12·B_21
C_12=A_11·B_12+A_12·B_22
C_21=A_21·B_11+A_22·B_21
C_22=A_21·B_12+A_22·B_22

You calculate the intermediaries

M_1 = (A_11+A_22)·(B_11+B_22)
M_2 = (A_21+A_22)·B_11
M_3 = A_11·(B_12-B_22)
M_4 = A_22·(B_21-B_11)
M_5 = (A_11+A_12)·B_22
M_6 = (A_21-A_11)·(B_11+B_12)
M_7 = (A_12-A_22)·(B_21+B_22)

and get the solutions as

C_11 = M_1+M_4-M_5+M_7
C_12 = M_3+M_5
C_21 = M_2+M_4
C_22 = M_1-M_2+M_3+M_6

Keep doing this recursively and you will need only O(n^(log_2 7)) ≈ O(n^2.807) multiplications (plus some extra work for the additions and subtractions) instead of the O(n^3) of the classical variant you use. For practical implementations, experiment to find a good cutoff size below which you switch to the classical variant.
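To make the formulas concrete, here is a minimal sketch of a single level of the scheme on 2x2 matrices, where every block is just one number. The class and method names (`Strassen2x2Check`, `strassen2x2`, `classical2x2`) are made up for illustration; the seven products are the M_1..M_7 listed above.

```java
import java.util.Arrays;

public class Strassen2x2Check {

    // One Strassen step on 2x2 matrices: 7 multiplications instead of 8.
    public static long[][] strassen2x2(long[][] a, long[][] b) {
        long m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
        long m2 = (a[1][0] + a[1][1]) * b[0][0];
        long m3 = a[0][0] * (b[0][1] - b[1][1]);
        long m4 = a[1][1] * (b[1][0] - b[0][0]);
        long m5 = (a[0][0] + a[0][1]) * b[1][1];
        long m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
        long m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
        return new long[][] {
                {m1 + m4 - m5 + m7, m3 + m5},
                {m2 + m4, m1 - m2 + m3 + m6}
        };
    }

    // Classical 2x2 product, used only as a reference for comparison.
    public static long[][] classical2x2(long[][] a, long[][] b) {
        long[][] c = new long[2][2];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 2; k++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    public static void main(String[] args) {
        long[][] a = {{1, 2}, {3, 4}};
        long[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(strassen2x2(a, b)));  // [[19, 22], [43, 50]]
        System.out.println(Arrays.deepToString(classical2x2(a, b))); // [[19, 22], [43, 50]]
    }
}
```

In the recursive algorithm the entries become submatrices and each `*` becomes a recursive call, which is where the seven-instead-of-eight saving compounds.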

As for code, try the following. (Note: I do not claim that it works; it was just the only one I found with a proper license attached (GPL 3.0).)

Also, a big warning: pretty much all the code I found implicitly assumes that the matrix dimensions are powers of two, so that the splitting step works seamlessly all the way down to the base case. You might need to add some logic to handle other sizes (or fall back to the base implementation).

Generally you should just use a library for this and save yourself the pain of implementation+testing.

import java.io.BufferedReader;
import java.io.FileReader;

public class Strassen {

    static final int STRASSEN_MULT = 0;
    static final int STANDARD_MULT = 1;
    static final int CROSSOVER = 64;

    static int MULT_MODE = STRASSEN_MULT;

    static boolean DEBUGGING = false;
    static boolean CROSSOVER_TEST = false;

    /*
        Formula:
        [[ A B ]     [[ E F ]      [[AE + BG   AF + BH]
         [ C D ]]  *  [ G H ]]  =   [CE + DG   CF + DH]]
     */

    public static int[][] strassenWithCrossover(int[][] X, int[][] Y, int crossover) {
        int[][] ret = new int[X.length][X.length];
        if (X.length <= crossover) {
            ret = standardMult(X, Y);
            return ret;
        }

        int n = X.length;

        int[][] A = getSubMatrix(X, 0, 0);
        int[][] D = getSubMatrix(X, n / 2, n / 2);

        int[][] E = getSubMatrix(Y, 0, 0);
        int[][] H = getSubMatrix(Y, n / 2, n / 2);

        int[][] P1 = strassenWithCrossover(A, subtract(Y, 0, n / 2, Y, n / 2, n / 2), crossover);
        int[][] P2 = strassenWithCrossover(add(X, 0, 0, X, 0, n / 2), H, crossover);
        int[][] P3 = strassenWithCrossover(add(X, n / 2, 0, X, n / 2, n / 2), E, crossover);
        int[][] P4 = strassenWithCrossover(D, subtract(Y, n / 2, 0, Y, 0, 0), crossover);
        int[][] P5 = strassenWithCrossover(add(X, 0, 0, X, n / 2, n / 2), add(Y, 0, 0, Y, n / 2, n / 2), crossover);
        int[][] P6 = strassenWithCrossover(subtract(X, 0, n / 2, X, n / 2, n / 2), add(Y, n / 2, 0, Y, n / 2, n / 2), crossover);
        int[][] P7 = strassenWithCrossover(subtract(X, 0, 0, X, n / 2, 0), add(Y, 0, 0, Y, 0, n / 2), crossover);

        int[][] AE_plus_BG = subtract(add(P5, P4), subtract(P2, P6));
        int[][] AF_plus_BH = add(P1, P2);
        int[][] CE_plus_DG = add(P3, P4);
        int[][] CF_plus_DH = subtract(add(P5, P1), add(P3, P7));

        assignSubMatrix(ret, 0, 0, AE_plus_BG);
        assignSubMatrix(ret, 0, n / 2, AF_plus_BH);
        assignSubMatrix(ret, n / 2, 0, CE_plus_DG);
        assignSubMatrix(ret, n / 2, n / 2, CF_plus_DH);

        return ret;

    }

    private static int[][] getSubMatrix(int[][] matrix, int rowStart, int colStart) {

        int[][] ret = new int[matrix.length / 2][matrix.length / 2];
        int i = rowStart;
        for (int row = 0; row < matrix.length / 2; row++) {
            int j = colStart;
            for (int col = 0; col < (matrix.length / 2); col++) {
                ret[row][col] = matrix[i][j];
                j++;
            }
            i++;
        }
        return ret;
    }

    private static void assignSubMatrix(int[][] matrix, int rowStart, int colStart, int[][] sub) {

        int i = rowStart;
        int j;
        for (int row = 0; row < matrix.length / 2; row++) {
            j = colStart;
            for (int col = 0; col < matrix.length / 2; col++) {
                matrix[i][j] = sub[row][col];
                j++;
            }
            i++;
        }
    }

    private static int[][] add(int[][] X, int[][] Y) {

        int[][] ret = new int[X.length][X.length];
        for (int row = 0; row < ret.length; row++) {
            for (int col = 0; col < ret.length; col++) {
                ret[row][col] = X[row][col] + Y[row][col];
            }
        }

        return ret;
    }

    private static int[][] add(int[][] X, int X_row_start, int X_col_start, int[][] Y, int Y_row_start, int Y_col_start) {

        int length = X.length / 2;
        int[][] ret = new int[length][length];
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                ret[row][col] = X[X_row_start + row][X_col_start + col] + Y[Y_row_start + row][Y_col_start + col];
            }
        }

        return ret;
    }

    private static int[][] subtract(int[][] X, int[][] Y) {

        int[][] ret = new int[X.length][X.length];
        for (int row = 0; row < ret.length; row++) {
            for (int col = 0; col < ret.length; col++) {
                ret[row][col] = X[row][col] - Y[row][col];
            }
        }

        return ret;

    }

    private static int[][] subtract(int[][] X, int X_row_start, int X_col_start, int[][] Y, int Y_row_start, int Y_col_start) {

        int length = X.length / 2;
        int[][] ret = new int[length][length];
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                ret[row][col] = X[X_row_start + row][X_col_start + col] - Y[Y_row_start + row][Y_col_start + col];
            }
        }

        return ret;

    }

    public static void main(String[] args) {

        if (args.length != 3) {
            System.out.println("Usage: ./strassen 0 dimension inputfile");
            System.exit(1);
        }

        int flag = Integer.parseInt(args[0]);
        int dimension = Integer.parseInt(args[1]);
        String inputfile = args[2];

        if (flag == 1) {
            DEBUGGING = true;
        } else if (flag == 2) {
            CROSSOVER_TEST = true;
        }

        Strassen me = new Strassen();

        if (CROSSOVER_TEST) {
            for (int i = 1 << 7; i < 1 << 16; i *= 2) {
                me.run(i, inputfile, MULT_MODE);
            }

        } else {
            me.run(dimension, inputfile, MULT_MODE);
        }

    }


    public void run(int dimension, String inputfile, int mode) {

        long startTime;

        int[][] X = new int[dimension][dimension];
        int[][] Y = new int[dimension][dimension];

        int[] elements = {0, 1, 2, 0, 2, 1, 1, 0, 2, 1, 2, 0, 2, 1, 0, 2, 0, 1};
        int pos = 0;

        try {
            BufferedReader br = new BufferedReader(new FileReader(inputfile));

            for (int i = 0; i < dimension; i++) {
                for (int j = 0; j < dimension; j++) {
                    if (CROSSOVER_TEST) {
                        X[i][j] = elements[pos++];
                        pos %= elements.length;

                    } else {
                        X[i][j] = Integer.parseInt(br.readLine());

                    }
                }
            }
            for (int k = 0; k < dimension; k++) {
                for (int l = 0; l < dimension; l++) {
                    if (CROSSOVER_TEST) {
                        Y[k][l] = elements[pos++];
                        pos %= elements.length;

                    } else {
                        Y[k][l] = Integer.parseInt(br.readLine());

                    }
                }
            }

            br.close();

        } catch (Exception e) {
            System.err.println("Caught Exception: " + e.getMessage());
        }

        if (DEBUGGING) {
            System.out.println("\n##### Reading Matrices X and Y from file ######\n");
            printMatrix(X,"X");
            printMatrix(Y,"Y");
        }

        if (mode == STANDARD_MULT) {
            int[][] Z = standardMult(X, Y);

            if (DEBUGGING) {
                System.out.println("Standard Product");
                printMatrix(Z, "Z");
            }

        } else if (mode == STRASSEN_MULT && CROSSOVER_TEST) {

            for (int crossover = 2; crossover <= dimension; crossover *= 2) {
                startTime = System.currentTimeMillis();

                int[][] paddedX = pad(X);
                int[][] paddedY = pad(Y);

                int[][] Z = strassenWithCrossover(paddedX, paddedY, crossover);

                printTimes("Strassen Product", startTime, dimension, crossover);
            }

        } else if (mode == STRASSEN_MULT) {
            startTime = System.currentTimeMillis();

            int[][] paddedX = pad(X);
            int[][] paddedY = pad(Y);

            int[][] Z = strassenWithCrossover(paddedX, paddedY, CROSSOVER);

            int[][] ZTrimmed = trim(Z, dimension);

            if (DEBUGGING) {
                printTimes("Strassen Product", startTime, dimension, CROSSOVER);
                printMatrix(ZTrimmed, "Z");

            } else {
                printDiagonal(ZTrimmed);
            }
        }


    }

    private static void printTimes(String mode, long startTime, int dimension, int crossover) {
        System.out.println(mode + " Crossover = " + crossover);
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("Finished Matrix Multiplication of %d dimensions in %d milliseconds, or %.2f minutes\n", dimension, time, ((double) time) / 60 / 1000);
        System.out.println();

    }

    private static int[][] pad(int[][] matrix) {

        int newDim = nextPowerOf2(matrix.length);
        if (newDim == matrix.length)
            return matrix;
        int[][] ret = new int[newDim][newDim];

        for (int row = 0; row < matrix.length; row++) {
            for (int col = 0; col < matrix.length; col++) {
                ret[row][col] = matrix[row][col];
            }
        }
        return ret;

    }

    private static int[][] trim(int[][] matrix, int dim) {

        int[][] ret = new int[dim][dim];
        for (int row = 0; row < dim; row++) {
            for (int col = 0; col < dim; col++) {
                ret[row][col] = matrix[row][col];
            }
        }

        return ret;
    }

    private static int nextPowerOf2(int length) {
        int exponent = (int) (Math.log(length) / Math.log(2));
        int reconstructed = (int) Math.pow(2, exponent);
        if (length != reconstructed) {
            return (int) Math.pow(2, exponent + 1);
        }
        return length;
    }

    // Standard Matrix Multiplication
    public static int[][] standardMult(int[][] A, int[][] B) {

        int dim = B.length;
        int[][] C = new int[B.length][B.length];
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                for (int k = 0; k < dim; k++) {
                    if (DEBUGGING) {
                        System.out.println(C[i][j] + " += " + A[i][k] + " * " + B[k][j]);
                    }

                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return C;
    }

    // Prints complete matrix
    public static void printMatrix(int[][] A) {
        int dim = A.length;

        for (int i = 0; i < dim; i++) {
            System.out.print(" [ ");

            for (int j = 0; j < dim; j++) {
                System.out.print(A[i][j] + " ");
            }
            System.out.println("]");
        }
        System.out.println();
    }

    // Prints complete matrix
    public static void printMatrix(int[][] A, String name) {
        System.out.println("Printing matrix " + name);
        printMatrix(A);
    }

    // Prints the list of values on the diagonal entries
    public static void printDiagonal(int[][] A) {
        for (int i = 0; i < A.length; i++) {
            System.out.println(A[i][i]);
        }
    }

}
Eisenknurr
  • I think this only works for square matrices. https://en.wikipedia.org/wiki/Strassen_algorithm – brando f Mar 28 '23 at 21:57
  • That is correct. You have to embed your matrices into the next larger identity matrix (it looks like this [link](https://cs.stackexchange.com/a/98015/97716) then) - alternatively you can implement the regular 3 nested loop algorithm and use this algorithm as a sub function on power-of-2-sized sub matrices. Also, depending on the domain you might only need a specialized function to multiply two 3x3 or 4x4 matrices - for the first case the link has a reference in it. – Eisenknurr Mar 29 '23 at 17:21
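As a rough sketch of how rectangular inputs can be fed to a square, power-of-two routine: the version below uses plain zero padding (which is what the `pad` helper in the answer's code does) rather than the identity embedding mentioned in the comment. All names here are hypothetical, and `squareMultiply` is only a stand-in for whatever square routine you plug in, such as the Strassen method above.

```java
import java.util.Arrays;

public class PadToSquare {

    // Smallest power of two that is >= n (assumes n >= 1).
    public static int nextPowerOfTwo(int n) {
        int highest = Integer.highestOneBit(n);
        return highest == n ? n : highest << 1;
    }

    // Copies 'm' into the top-left corner of a size x size zero matrix.
    public static int[][] padWithZeros(int[][] m, int size) {
        int[][] out = new int[size][size];
        for (int r = 0; r < m.length; r++) {
            System.arraycopy(m[r], 0, out[r], 0, m[r].length);
        }
        return out;
    }

    // Keeps only the top-left rows x cols block of 'm'.
    public static int[][] trim(int[][] m, int rows, int cols) {
        int[][] out = new int[rows][];
        for (int r = 0; r < rows; r++) {
            out[r] = Arrays.copyOf(m[r], cols);
        }
        return out;
    }

    // Stand-in for a square-only routine; here simply the classical product.
    public static int[][] squareMultiply(int[][] x, int[][] y) {
        int n = x.length;
        int[][] z = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    z[i][j] += x[i][k] * y[k][j];
        return z;
    }

    // Multiplies an m x n matrix by an n x p matrix via padding and trimming.
    public static int[][] multiplyRectangular(int[][] a, int[][] b) {
        int largest = Math.max(Math.max(a.length, a[0].length), b[0].length);
        int size = nextPowerOfTwo(largest);
        int[][] product = squareMultiply(padWithZeros(a, size), padWithZeros(b, size));
        return trim(product, a.length, b[0].length);
    }
}
```

The extra zero rows and columns contribute nothing to the sums, so the top-left block of the padded product equals the true product. The cost is that a very elongated matrix gets padded to a much larger square, which may wipe out any gain.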

You can use three nested IntStreams to multiply two matrices. The outer stream iterates over the rows of the first matrix, and the middle stream iterates over the columns of the second matrix to build each row of the result. The innermost stream computes a single entry of the result: the sum of the products of the entries of the i-th row of the first matrix and the j-th column of the second matrix.

[figure: matrix multiplication]

import java.util.Arrays;
import java.util.stream.IntStream;

// dimensions
int m = 4;
int n = 5;
int p = 3;

// 'a' is an 'm×n' matrix
int[][] a = {
        {2, 4, -1, 7, 9},
        {-2, -6, 6, 4, 3},
        {3, 8, 2, 7, -5},
        {8, 5, 3, 2, -7}};

// 'b' is an 'n×p' matrix
int[][] b = {
        {3, -2, 1},
        {8, 0, -9},
        {2, 5, 7},
        {-1, 6, 9},
        {3, -3, 5}};
// 'c' is an 'm×p' matrix
int[][] c = IntStream.range(0, m)
        .mapToObj(i -> IntStream.range(0, p)
                .map(j -> IntStream.range(0, n)
                        // multiply the entries
                        // of the i-th row of 'a'
                        // and the j-th column of 'b'
                        .map(k -> a[i][k] * b[k][j])
                        // the sum of the products
                        .sum())
                .toArray())
        .toArray(int[][]::new);
// output
Arrays.stream(c).map(Arrays::toString).forEach(System.out::println);
[56, 6, 67]
[-37, 49, 145]
[55, 61, -17]
[47, 32, -33]
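Since each row of the result is computed independently, the outer stream can be switched to parallel mode. The following is only a sketch of that idea: the class name `ParallelStreamMultiply` is made up, the dimensions are read from the arrays instead of separate variables, and whether it is actually faster depends on the matrix size and the machine, so measure before relying on it.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class ParallelStreamMultiply {

    // Multiplies an m x n matrix 'a' by an n x p matrix 'b'.
    // Rows of the result are computed in parallel; toArray keeps them in order.
    public static int[][] multiply(int[][] a, int[][] b) {
        int n = b.length;
        int p = b[0].length;
        return IntStream.range(0, a.length)
                .parallel()
                .mapToObj(i -> IntStream.range(0, p)
                        .map(j -> IntStream.range(0, n)
                                .map(k -> a[i][k] * b[k][j])
                                .sum())
                        .toArray())
                .toArray(int[][]::new);
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2}, {3, 4}, {5, 6}};
        int[][] b = {{1, 2, 3, 4}, {5, 6, 7, 8}};
        Arrays.stream(multiply(a, b))
                .map(Arrays::toString)
                .forEach(System.out::println);
    }
}
```

For small matrices the overhead of splitting the work across threads usually outweighs the benefit, so the sequential version above remains the sensible default.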

See also: Parallelized Matrix Multiplication

Community

Algorithm of the matrix multiplication:

// dimensions
int m = 3; // rows of 'a' matrix
int n = 2; // columns of 'a' matrix
           // and rows of 'b' matrix
int p = 4; // columns of 'b' matrix

// matrices 'a=m×n', 'b=n×p'
int[][] a = {{1, 2}, {3, 4}, {5, 6}},
        b = {{1, 2, 3, 4}, {5, 6, 7, 8}},
        // resulting matrix 'c=m×p'
        c = new int[m][p];
// iterate over the rows of the 'a' matrix
for (int i = 0; i < m; i++) {
    // iterate over the columns of the 'b' matrix
    for (int j = 0; j < p; j++) {
        // iterate over the columns of the 'a'
        // matrix, aka rows of the 'b' matrix
        for (int k = 0; k < n; k++) {
            // sum of the products of
            // the i-th row of 'a' and
            // the j-th column of 'b'
            c[i][j] += a[i][k] * b[k][j];
        }
    }
}
// output
for (int[] rowA : a) System.out.println(Arrays.toString(rowA));
//[1, 2]
//[3, 4]
//[5, 6]
for (int[] rowB : b) System.out.println(Arrays.toString(rowB));
//[1, 2, 3, 4]
//[5, 6, 7, 8]
for (int[] rowC : c) System.out.println(Arrays.toString(rowC));
//[11, 14, 17, 20]
//[23, 30, 37, 44]
//[35, 46, 57, 68]
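A common variation on the loop above, sketched here under the assumption that the matrices are stored as Java `int[][]` arrays (row by row): swapping the two inner loops to i-k-j order makes the innermost loop walk along one row of 'b' and one row of 'c' instead of jumping down a column of 'b'. The complexity is unchanged, but the memory access pattern is often friendlier to the cache. The class and method names are illustrative only.

```java
public class LoopOrderMultiply {

    // Same product as the i-j-k version, computed in i-k-j order.
    public static int[][] multiplyIkj(int[][] a, int[][] b) {
        int m = a.length;     // rows of 'a'
        int n = b.length;     // columns of 'a' and rows of 'b'
        int p = b[0].length;  // columns of 'b'
        int[][] c = new int[m][p];
        for (int i = 0; i < m; i++) {
            int[] rowC = c[i];
            for (int k = 0; k < n; k++) {
                int aik = a[i][k];  // reused for the whole inner loop
                int[] rowB = b[k];
                for (int j = 0; j < p; j++) {
                    rowC[j] += aik * rowB[j];
                }
            }
        }
        return c;
    }
}
```

Any speedup depends on matrix size, the JIT, and the hardware, so it is worth benchmarking both orders on your own data.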

See also: Multiplication of square matrices using 2D lists


You can do it the traditional way, with three nested for loops, like this:

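// a is m×n, b is n×p, c is a zero-initialized m×p result
for (int i = 0; i < m; i++) {
  for (int j = 0; j < p; j++) {
    for (int k = 0; k < n; k++) {
      c[i][j] += a[i][k] * b[k][j];
    }
  }
}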

The complexity will be O(n³).

Or you can use the newer approach: treat the rows of the first matrix as a Stream and use Stream operations on the matrices.

// requires: import java.util.Arrays; import java.util.stream.IntStream;
public static void main(String[] args) {
    double[][] m1 = {{4, 8}, {0, 2}, {1, 6}};
    double[][] m2 = {{5, 2}, {9, 4}};

    double[][] result = Arrays.stream(m1)
            .map(r -> IntStream.range(0, m2[0].length)
                    .mapToDouble(i -> IntStream.range(0, m2.length)
                            .mapToDouble(j -> r[j] * m2[j][i]).sum())
                    .toArray())
            .toArray(double[][]::new);

    System.out.println(Arrays.deepToString(result));
    // [[92.0, 40.0], [18.0, 8.0], [59.0, 26.0]]
}

These are just sample values; you can substitute your own matrices here.

Community
Vaibhav Atray