4

I copied Strassen's algorithm from somewhere and ran it. Here is the output:

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

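where strassen1 is the "dynamic" version (it allocates new sub-matrices at every recursion step), strassen2 is the "cached" version (it reuses preallocated scratch arrays), and classical is the ordinary triple-loop matrix multiplication. This suggests that the old, simple classical algorithm is the best. Is that true, or am I wrong somewhere? Here's the code in Java: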

import java.util.Arrays;
import java.util.Random;

class TestIntMatrixMultiplication {

    /**
     * Benchmark driver: {@code java TestIntMatrixMultiplication [n] [seed]} (both default to 256).
     *
     * <p>Fills two n x n matrices with pseudo-random entries in [0, 100). It first checks each
     * Strassen variant against the classical product, because timings of an implementation that
     * returns the wrong answer are meaningless. For n < 64 it then prints the operands and the
     * classical and strassen2 products. Otherwise it times every implementation three times
     * (strassen1 only when n <= 256).
     *
     * @throws IllegalArgumentException if n < 1
     */
    public static void main (String...args) throws Exception {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        if (n < 1) {
            // n = 0 sends the Strassen recursions into failure instead of a clear error.
            throw new IllegalArgumentException("n must be at least 1, got " + n);
        }
        final Random random = new Random(seed);

        final int[][] a = new int[n][n];
        final int[][] b = new int[n][n];
        final int[][] c = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            }
        }

        System.out.println("n = " + n);

        // Same n <= 256 cut-off for strassen1 as in the timing loop below.
        verify(a, b, n <= 256);

        if (n < 64) {
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        }

        for (int i = 0; i < 3; ++i) {
            timeMultiplies1(a, b, c);
            if (n <= 256) {
                timeMultiplies2(a, b, c);
            }
            timeMultiplies3(a, b, c);
        }
    }

    /**
     * Compares each Strassen implementation with the classical product of a and b and prints
     * whether it matches.
     *
     * @param includeStrassen1 whether strassen1 is checked as well
     */
    static void verify (int[][] a, int[][] b, boolean includeStrassen1) {
        final int n = a.length;
        final int[][] expected = Classical.mult(new int[n][n], a, b);
        if (includeStrassen1) {
            // Check strassen1's return value rather than the array passed in.
            report("strassen 1", expected, strassen1.mult(new int[n][n], a, b));
        }
        report("strassen2", expected, strassen2.mult(new int[n][n], a, b));
    }

    /** Prints whether {@code actual} is element-wise equal to {@code expected}. */
    private static void report (String name, int[][] expected, int[][] actual) {
        final boolean same = Arrays.deepEquals(expected, actual);
        System.out.println(name + (same ? " matches classical" : " does NOT match classical"));
    }

    /** Times one {@code Classical.mult} call into c and prints the elapsed milliseconds. */
    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        Classical.mult(c, a, b);
        System.out.println("classical took " + elapsedMillis(start) + "ms");
    }

    /** Times one {@code strassen1.mult} call and prints the elapsed milliseconds. */
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen1.mult(c, a, b);
        System.out.println("strassen 1 took " + elapsedMillis(start) + "ms");
    }

    /** Times one {@code strassen2.mult} call into c and prints the elapsed milliseconds. */
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen2.mult(c, a, b);
        System.out.println("strassen2 took " + elapsedMillis(start) + "ms");
    }

    /**
     * Whole milliseconds elapsed since {@code startNanos}, a value read from
     * {@code System.nanoTime()}. nanoTime is monotonic; currentTimeMillis is wall-clock time
     * and can jump.
     */
    private static long elapsedMillis (long startNanos) {
        return (System.nanoTime() - startNanos) / 1000000L;
    }

    /** Prints m one row per line: "[", a tab, then each value followed by a tab, then "]". */
    static void dumpMatrix (int[][] m) {
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}

class strassen1{

    /** Returns the display name of this implementation. */
    public String getName () {
        return "Strassen(dynamic)";
    }

    /**
     * Computes a * b with Strassen's algorithm and stores the product in {@code c}. This is the
     * same contract as {@code Classical.mult} and {@code strassen2.mult}.
     *
     * @param c destination, at least n x n where n = a.length; if null, a freshly allocated
     *     product is returned instead
     * @param a left operand, n x n
     * @param b right operand, n x n
     * @return {@code c} holding the product, or a new array if {@code c} is null
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int[][] product = strassenMatrixMultiplication(a, b);
        if (c == null) {
            return product;
        }
        for (int i = 0; i < product.length; i++) {
            System.arraycopy(product[i], 0, c[i], 0, product.length);
        }
        return c;
    }

    /**
     * Multiplies two square matrices with Strassen's seven-product recursion. New sub-matrices
     * are allocated at every level, and the recursion goes all the way down to 1 x 1 blocks.
     *
     * <p>The halving recursion needs a power-of-two size. Other sizes are zero-padded up to the
     * next power of two, and the result is trimmed back to n x n. A 0 x 0 input yields a 0 x 0
     * result.
     *
     * @return a new n x n array holding A * B, where n = A.length
     */
    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
        int n = A.length;

        if (n == 0) {
            // Halving 0 never reaches the n == 1 base case.
            return new int[0][0];
        }
        if ((n & (n - 1)) != 0) {
            // Halving an odd size would silently drop the last row/column.
            final int size = Integer.highestOneBit(n) << 1;
            final int[][] product = strassenMatrixMultiplication(padTo(A, size), padTo(B, size));
            final int[][] trimmed = new int[n][n];
            divideArray(product, trimmed, 0, 0);
            return trimmed;
        }

        int [][] result = new int[n][n];

        if(n == 1) {
            result[0][0] = A[0][0] * B[0][0];
        } else {
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
            int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
            int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
            int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
            int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
            int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
            int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

            int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
            int [][] C12 = addMatrices(P3, P5);
            int [][] C21 = addMatrices(P2, P4);
            int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

            copySubArray(C11, result, 0 , 0);
            copySubArray(C12, result, 0 , n/2);
            copySubArray(C21, result, n/2, 0);
            copySubArray(C22, result, n/2, n/2);
        }

        return result;
    }

    /** Returns a size x size copy of square matrix m, with zeros in the extra rows and columns. */
    private static int[][] padTo(int[][] m, int size) {
        final int[][] padded = new int[size][size];
        copySubArray(m, padded, 0, 0);
        return padded;
    }

    /** Returns a new n x n matrix A + B, where n = A.length. */
    public static int [][] addMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                result[i][j] = A[i][j] + B[i][j];
            }
        }

        return result;
    }

    /** Returns a new n x n matrix A - B, where n = A.length. */
    public static int [][] subtractMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for(int i=0; i<n; i++) {
            for(int j=0; j<n; j++) {
                result[i][j] = A[i][j] - B[i][j];
            }
        }

        return result;
    }

    /** Copies the child.length x child.length block of parent at (iB, jB) into child. */
    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                child[i1][j1] = parent[i2][j2];
    }

    /** Writes square matrix child into parent with its top-left corner at (iB, jB). */
    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
                parent[i2][j2] = child[i1][j1];
    }
}
class strassen2{

    /** Returns the display name of this implementation. */
    public String getName () {
        return "Strassen(cached)";
    }

    // Scratch storage shared by every call. Because it is static and mutable, this class is
    // not safe for concurrent use. For a top-level size n, each array needs n/2 rows and n-1
    // columns. Each recursion level works in its own column band starting at `offs`
    // (0, n/2, n/2 + n/4, ...), so nested calls never overwrite their caller's data.
    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    static int [][] t0;
    static int [][] t1;

    /**
     * Computes c = a * b for n x n matrices (n = c.length) with Strassen's algorithm. The static
     * scratch arrays are reused between calls.
     *
     * <p>The recursive kernel only handles power-of-two sizes. Other sizes are zero-padded up to
     * the next power of two, and the product is copied back into c. For n = 0, c is returned
     * unchanged.
     *
     * @return {@code c}
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int n = c.length;

        if (n == 0) {
            return c;
        }
        if ((n & (n - 1)) != 0) {
            // Halving an odd size would silently drop the last row/column.
            final int size = Integer.highestOneBit(n) << 1;
            final int[][] product = mult(new int[size][size], padTo(a, n, size), padTo(b, n, size));
            for (int i = 0; i < n; i++) {
                System.arraycopy(product[i], 0, c[i], 0, n);
            }
            return c;
        }

        ensureScratch(n);
        mult(c, a, b, 0, 0, n, 0);

        return c;
    }

    /**
     * Allocates the scratch arrays only when the current ones have fewer than n/2 rows or
     * fewer than n-1 columns. Both dimensions are compared so that arrays left over from an
     * equal or larger n are reused.
     */
    private static void ensureScratch (int n) {
        final int rows = n / 2;
        final int cols = n - 1;
        final boolean fits = p1 != null && p1.length >= rows && (rows == 0 || p1[0].length >= cols);
        if (fits) {
            return;
        }
        p1 = new int[rows][cols];
        p2 = new int[rows][cols];
        p3 = new int[rows][cols];
        p4 = new int[rows][cols];
        p5 = new int[rows][cols];
        p6 = new int[rows][cols];
        p7 = new int[rows][cols];
        t0 = new int[rows][cols];
        t1 = new int[rows][cols];
    }

    /** Returns a size x size zero-padded copy of the top-left n x n block of m. */
    private static int[][] padTo (int[][] m, int n, int size) {
        final int[][] padded = new int[size][size];
        for (int i = 0; i < n; i++) {
            System.arraycopy(m[i], 0, padded[i], 0, n);
        }
        return padded;
    }

    /**
     * Recursive kernel. Reads the n x n blocks of a and b whose top-left corner is (i0, j0) and
     * writes their product into the same block of c. n must be a power of two (>= 1), and the
     * scratch arrays must already be large enough.
     *
     * @param offs first scratch column this level may use for its own P/T blocks
     */
    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
        if(n == 1) {
            c[i0][j0] = a[i0][j0] * b[i0][j0];
        } else {
            final int nBy2 = n/2;

            final int i1 = i0 + nBy2;
            final int j1 = j0 + nBy2;

            // offset applied to 'p' j index so recursive calls don't overwrite data
            final int jp0 = offs;
            final int jp1 = nBy2 + offs;

            // P1 <- (A11 + A22)(B11 + B22)
            //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P2 <- (A21 + A22)B11
            //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0];
                }
            }

            mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P3 <- A11(B12 - B22)
            //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
                }
            }

            mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P4 <- A22(B21 - B11)
            //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
                }
            }

            mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P5 <- (A11 + A12) B22
            //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j1];
                }
            }

            mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P6 <- (A21 - A11)(B11 + B12)
            //  T0 <- (A21 - A11), T1 <- (B11 + B12), P6 <- T0 * T1
            // (B11 + B12) is what makes C22 = P1 + P3 - P2 + P6 come out right.
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i0][j + j1];
                }
            }

            mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P7 <- (A12 - A22)(B21 + B22)
            //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // combine
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    // C11 = P1 + P4 - P5 + P7;
                    c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                    // C12 = P3 + P5;
                    c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                    // C21 = P2 + P4;
                    c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                    // C22 = P1 + P3 - P2 + P6;
                    c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
                }
            }
        }
    }

    /** Debug helper: prints every scratch array, each preceded by its name. */
    void dumpInternal () {
        dump("P1", p1);
        dump("P2", p2);
        dump("P3", p3);
        dump("P4", p4);
        dump("P5", p5);
        dump("P6", p6);
        dump("P7", p7);
        dump("T0", t0);
        dump("T1", t1);
    }

    /**
     * Prints a label line followed by m, one row per line: "[", a tab, then each value
     * followed by a tab, then "]". Kept local so this class does not depend on the driver.
     */
    private static void dump (String label, int[][] m) {
        System.out.println(label);
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}


class Classical{
    /** Display name of the textbook implementation. */
    public String getName () {
        return "classic";
    }

    /**
     * Textbook triple-loop product: for n = a.length, every c[row][col] with row, col < n is
     * overwritten by the dot product of row {@code row} of a and column {@code col} of b.
     *
     * @return the same {@code c} array that was passed in
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int size = a.length;

        int row = 0;
        while (row < size) {
            final int[] leftRow = a[row];
            final int[] outRow = c[row];

            for (int col = 0; col < size; ++col) {
                int dot = 0;
                for (int k = 0; k < size; ++k) {
                    dot += leftRow[k] * b[k][col];
                }
                outRow[col] = dot;
            }

            ++row;
        }

        return c;
    }
}
hammar
  • 138,522
  • 17
  • 304
  • 385
  • 2
    You are wrong somewhere: at least in assuming that n = 256 is representative of all n, in assuming that code copied 'from somewhere' is the best possible implementation, and probably on a lot of other points. – Gunther Piez Jun 06 '11 at 13:58
  • First of all, that's not a correct way to benchmark algorithms in Java. Read [this](http://stackoverflow.com/questions/504103/how-do-i-write-a-correct-micro-benchmark-in-java) first and check whether it changes the results. – Voo Jun 06 '11 at 14:15

3 Answers

5

Issues I see:

1) Your Strassen multiply allocates memory dynamically all the time. This is going to kill performance.

2) Your Strassen multiply should switch over to conventional multiplication for small sizes rather than recursing all the way down (though this optimization sort of invalidates your test).

3) Your matrix size may simply be too small to see the difference.

You should do comparisons with several different sizes, perhaps 256, 512, 1024, 2048, 4096, 8192... Then plot the times and look at the trends. You will probably want the matrix size on a log scale, since the sizes are all powers of 2.

Strassen is only faster for large N. How large will depend a lot on the implementation. What you have done for classical is only a basic implementation and is not optimal on a modern machine either.

phkahler
  • 5,687
  • 1
  • 23
  • 31
  • 1
    I agree generally. But 2) is one of the most important aspects for any divide and conquer algorithm (eg compare a naive merge sort to one that uses quicksort/selection sort for its leaves) so I don't see why that would invalidate the test? – Voo Jun 06 '11 at 16:14
  • If the test is to show Strassen is faster, then you want to always use it. Only AFTER you know the crossover point should it be switched depending on size. It should still be faster for large N though, so in that sense it doesn't invalidate anything. – phkahler Jun 06 '11 at 17:42
  • 1
    Surely, the switch-over doesn't invalidate the result. If a mix of A and B is faster than pure A, then B is an improvement. Consider merge sort vs. selection sort on three elements. Of course selection sort is going to win: both do 3 comparisons, but selection sort is simpler. That doesn't mean that merge sort isn't better on large arrays. – Neil G Jun 06 '11 at 18:34
2

Implementation questions aside, I think you're misunderstanding the algorithm's performance. Like phkahler said, your expectations are a little off for the performance of the algorithm. Divide-and-conquer algorithms work well for large inputs because they recursively break the problem into sub-problems which can be solved more quickly.

However, the overhead associated with this splitting can make the algorithm run (sometimes much) slower for small or even medium-sized inputs. Typically, the theoretical analysis of an algorithm like Strassen's includes a so-called "breakpoint" calculation: the input size above which the divide-and-conquer approach, despite its splitting overhead, becomes faster than the naive technique.

Your code needs to include a check on the size of the input that switches to the naive technique at the breakpoint.

pg1989
  • 1,010
  • 6
  • 13
1

Write down what the Strassen algorithm does for a 2 x 2 matrix. Count the operations. The number is absolutely ridiculous. It's stupid to use Strassen's method for a 2x2 matrix. Same for a 3 x 3, or 4 x 4, matrix and probably quite a way up.

gnasher729
  • 51,477
  • 5
  • 75
  • 98