
I'm trying to optimize a matrix multiplication implemented in Java with nested loops. I'm planning to use the Java 17 Vector API to improve performance.

I have read the documentation of the Vector API, but I am not sure how to apply it to my matrix multiplication code.

I'm using the following code to multiply the matrices, based on this tutorial.

    import java.util.Arrays;

    public class MatMul {
        public static void main(String[] args) {
            int row1 = 4, col1 = 3, row2 = 3, col2 = 4;

            int[][] A = { { 1, 1, 1 },
                          { 2, 2, 2 },
                          { 3, 3, 3 },
                          { 4, 4, 4 } };

            int[][] B = { { 1, 1, 1, 1 },
                          { 2, 2, 2, 2 },
                          { 3, 3, 3, 3 } };

            // Multiplication is only possible if A's column count equals B's row count
            if (row2 != col1) {
                System.out.println("Multiplication Not Possible");
                return;
            }

            // Matrix to store the result; the product is row1 x col2
            int[][] C = new int[row1][col2];

            // Multiply the two matrices: C[i][j] = sum over k of A[i][k] * B[k][j]
            for (int i = 0; i < row1; i++) {
                for (int j = 0; j < col2; j++) {
                    for (int k = 0; k < row2; k++)
                        C[i][j] += A[i][k] * B[k][j];
                }
            }

            System.out.println(Arrays.deepToString(C));
        }
    }
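The innermost loop above walks down a column of `B` (`B[k][j]` with `k` varying), which jumps between separate row arrays on every iteration. A common first step, independent of the Vector API, is to swap the `j` and `k` loops so the inner loop runs sequentially along a row of `B` and a row of `C` with `A[i][k]` held constant; this i-k-j order is also the shape a vectorized inner loop needs. A minimal sketch (the class name `IkjMatMul` is made up for illustration); a simple inner loop like this is the kind HotSpot's JIT can sometimes auto-vectorize, but that is not guaranteed.

```java
import java.util.Arrays;

public class IkjMatMul {
    // Same product as the i-j-k triple loop, with the j and k loops swapped (i-k-j order)
    static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length, m = b.length, p = b[0].length;
        int[][] c = new int[n][p];
        for (int i = 0; i < n; i++) {
            int[] ci = c[i];            // row i of C, updated sequentially
            for (int k = 0; k < m; k++) {
                int aik = a[i][k];      // A[i][k] is constant for the whole inner loop
                int[] bk = b[k];        // row k of B, read sequentially
                for (int j = 0; j < p; j++) {
                    ci[j] += aik * bk[j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = { {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, {4, 4, 4} };
        int[][] b = { {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3} };
        // prints [[6, 6, 6, 6], [12, 12, 12, 12], [18, 18, 18, 18], [24, 24, 24, 24]]
        System.out.println(Arrays.deepToString(multiply(a, b)));
    }
}
```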

My questions are:

  • How do I handle cases where the array length or the matrix size is not a multiple of the vector length?
  • I would appreciate some guidance or examples on how to use the Java 17 Vector API to optimize matrix multiplication.
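One common way to apply the Vector API to this triple loop, sketched under stated assumptions: the loops are reordered to i-k-j so the innermost loop runs along a row of `B` and a row of `C`, `A[i][k]` is broadcast to every lane, full vectors are processed up to `SPECIES.loopBound(p)`, and the columns left over when the row length is not a multiple of the lane count are finished with a scalar loop. The class name `VectorMatMul` is hypothetical. In Java 17, `jdk.incubator.vector` is still an incubator module, so it has to be enabled with `--add-modules jdk.incubator.vector` at compile and run time. This sketch does no cache blocking, so on large matrices it will likely still be well behind a tuned BLAS library.

```java
import java.util.Arrays;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorSpecies;

public class VectorMatMul {
    // Widest int vector shape the platform prefers (lane count depends on the CPU/JVM)
    static final VectorSpecies<Integer> SPECIES = IntVector.SPECIES_PREFERRED;

    // C = A (n x m) * B (m x p); each row is a separate int[], as in the question
    static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length, m = b.length, p = b[0].length;
        int[][] c = new int[n][p];
        int upper = SPECIES.loopBound(p); // p rounded down to a multiple of the lane count
        for (int i = 0; i < n; i++) {
            int[] ci = c[i];
            for (int k = 0; k < m; k++) {
                int aik = a[i][k];
                int[] bk = b[k];
                IntVector av = IntVector.broadcast(SPECIES, aik); // A[i][k] in every lane
                int j = 0;
                // Full vectors: ci[j .. j+VLEN) += aik * bk[j .. j+VLEN)
                for (; j < upper; j += SPECIES.length()) {
                    IntVector bv = IntVector.fromArray(SPECIES, bk, j);
                    IntVector cv = IntVector.fromArray(SPECIES, ci, j);
                    cv.add(av.mul(bv)).intoArray(ci, j);
                }
                // Scalar tail: the last p % VLEN columns (all of them when p < VLEN)
                for (; j < p; j++) {
                    ci[j] += aik * bk[j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] a = { {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, {4, 4, 4} };
        int[][] b = { {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3} };
        // prints [[6, 6, 6, 6], [12, 12, 12, 12], [18, 18, 18, 18], [24, 24, 24, 24]]
        System.out.println(Arrays.deepToString(multiply(a, b)));
    }
}
```

Instead of the scalar tail, the last partial vector can also be handled with a mask from `SPECIES.indexInRange(j, p)` passed to the masked `fromArray`/`intoArray` overloads; whether that beats the scalar tail depends on the hardware.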
Peter Cordes
Isuru Perera
  • Are you mostly interested in very small matrices like this, like 4x4 where one SIMD vector can hold a whole row? Because that's very different from medium to large matrices where it takes a loop over many vectors to cover a whole row. (For that, the appendix in https://www.akkadia.org/drepper/cpumemory.pdf has an example of vectorizing with x86 intrinsics in C; assuming Java provides a portable equivalent of that, it should translate. See also [How much of ‘What Every Programmer Should Know About Memory’ is still valid?](https://stackoverflow.com/a/47714514) re: that guide) – Peter Cordes May 26 '23 at 13:18
  • If you're just trying to learn some SIMD stuff in general, matrix multiplication is a relatively hard problem with many opportunities and challenges for optimization with SIMD and cache-blocking; it's one that's been studied a *lot* because it's so important for a lot of use-cases. (The row-major layout makes it a disaster to stride down one column of a non-tiny matrix, but working along rows of both inputs and of the output means a lot of load/store work.) So writing your own code that comes anywhere close to a well-tuned BLAS library is a big challenge. – Peter Cordes May 26 '23 at 13:24
  • Some interesting but non-trivial problems to SIMD vectorize include prefix-sum ([Accumulating a running-total (prefix sum) horizontally across an \_\_m256i vector](https://stackoverflow.com/q/69694914)), counting matches ([How to count character occurrences using SIMD](https://stackoverflow.com/q/54541129)), turning 16 bytes into 32 ASCII hex digits ([How to convert a binary integer number to a hex string?](https://stackoverflow.com/q/53823756)) – Peter Cordes May 26 '23 at 13:24
  • I'm mostly interested in multiplying medium to large matrices where rows and columns don't fit in a single SIMD register. But I think the Java Vector API handles mapping large vectors onto AVX instructions. – Isuru Perera May 26 '23 at 13:25
  • Ok, so your real problem is *very* different from a case like 3x4 * 4x3 where a vector will hold two whole rows or more, and you'd be doing a lot of shuffling of 32-byte vectors to arrange the data. (If you search on SO for existing questions about 3x3 and 4x4 matmul with SIMD intrinsics for x86 SSE/AVX or ARM NEON, you'll find examples of that.) – Peter Cordes May 26 '23 at 13:28
  • What have you tried that didn't work? – aled May 26 '23 at 16:24
  • Hi @aled, I've added the approach I'm currently trying here (https://github.com/isuruperera/j17-vector-matmul/blob/main/src/main/java/org/example/Main.java), but in that code the `SPECIES.loopBound()` method always returns 0. I'm having a hard time understanding what the documentation means by "the largest multiple of the vector length not greater than the given length". Does that mean the hardware-level vector length? – Isuru Perera May 27 '23 at 04:21
  • I noticed that the above code works for vectors larger than 16, but for vectors smaller than 16 elements, SPECIES.loopBound() method returns 0 – Isuru Perera May 27 '23 at 04:51
  • "the largest multiple of the vector length not greater than the given length" is `n` (array size) rounded down to a multiple of 4 for example. `n - (n % vec_len)` which simplifies to `n & -vec_len` since the SIMD vector length will always be a power of 2. It's the number of elements you can process with a loop like `for (i=0 ; i < n-(vec_len-1) ; i += vec_len)` that does only full vectors, and doesn't read past the end of the array in the final iteration. Notice how for vec_len=1, that loop would be `for (i=0 ; i – Peter Cordes May 27 '23 at 05:37
  • It's the same problem you'd run into when unrolling manually. For learning basics like that, I'd recommend just writing a function that sums an array and works for any array length, or that does `a[i] = b[i] * constant + constant2` or something. (Fun fact: a close relative of that is the level-1 BLAS operation DAXPY, `y = a*x + y`, where x and y are double-precision arrays (the "D") and a is a scalar.) – Peter Cordes May 27 '23 at 05:42
  • Hi @PeterCordes, I have implemented this simple level-1 BLAS-style code that does `a[i] = b[i] * constant + constant2` here (https://github.com/isuruperera/j17-vector-matmul/blob/main/src/main/java/org/example/VectorToScalarMultiplication.java). I'm assuming a vector * scalar operation is something SIMD is very good at. I also went through your answer here (https://stackoverflow.com/questions/68061254/why-is-the-java-vector-api-so-slow-compared-to-scalar). However, after repeated tests with fairly large arrays, I'm seeing that the plain loop performs better (I assume the JVM is unrolling it). – Isuru Perera May 27 '23 at 06:26
  • I also verified that my CPU supports AVX, AVX2 and AVX512F. Shouldn't vector computations theoretically run faster? My tests show otherwise. – Isuru Perera May 27 '23 at 06:27
  • Simple "vertical" operations like DAXPY might be so easy that the JVM auto-vectorizes scalar code. Or if you're testing on large arrays (larger than your L2 cache size), you might bottleneck on memory bandwidth even with scalar asm, especially if you're using an 8-byte type like `double`, and/or your CPU is an Intel Xeon where memory latency is high and single-core bandwidth is low (especially on Skylake-avx512 / Cascade Lake.) See [Why is Skylake so much better than Broadwell-E for single-threaded memory throughput?](//stackoverflow.com/q/39260020) for client vs. server per-core bandwidth – Peter Cordes May 27 '23 at 06:47
  • Also, with 512-bit vectors on Intel CPUs, if your data isn't aligned by 64, that actually hurts memory bandwidth from L3 or DRAM for some reason, about 15% slower than aligned, vs. only a couple % slower with 256-bit vectors. You'd expect that the extra costs of misaligned loads/stores would be hidden by out-of-order exec and you'd still just bottleneck on the full memory bandwidth a single core can use (thanks to hardware prefetch), but that's not the case. – Peter Cordes May 27 '23 at 06:49
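Several comments above turn on what `VectorSpecies.loopBound(n)` returns: `n` rounded down to a multiple of the lane count (`SPECIES.length()`), not of the hardware register width in bytes. The sketch below reproduces that arithmetic in plain Java and shows the "full chunks, then scalar tail" loop structure for the `a[i] = b[i] * constant + constant2` operation. The helper `loopBound(n, vlen)`, the method `scaleAdd` and the `vlen` parameter are hypothetical stand-ins for `SPECIES.loopBound(n)` and `SPECIES.length()`; the inner `j` loop stands in for one vector operation. If the species in use has 16 lanes (for example 32-bit elements in 512-bit vectors), any length below 16 gives a bound of 0 and every element goes through the tail loop, which would match the behaviour described above.

```java
import java.util.Arrays;

public class LoopBoundDemo {
    // Hypothetical helper mirroring the arithmetic of VectorSpecies.loopBound(n):
    // n rounded down to a multiple of vlen. vlen is a power of two, so for
    // non-negative n this equals n - (n % vlen).
    static int loopBound(int n, int vlen) {
        return n & -vlen;
    }

    // a[i] = b[i] * c1 + c2, processed in vlen-sized chunks plus a scalar tail.
    // The inner j-loop stands in for a single vector operation on vlen lanes.
    static void scaleAdd(int[] a, int[] b, int c1, int c2, int vlen) {
        int upper = loopBound(b.length, vlen);
        int i = 0;
        for (; i < upper; i += vlen) {      // full chunks only: never reads past the end
            for (int j = i; j < i + vlen; j++) {
                a[j] = b[j] * c1 + c2;
            }
        }
        for (; i < b.length; i++) {         // leftover b.length % vlen elements
            a[i] = b[i] * c1 + c2;
        }
    }

    public static void main(String[] args) {
        System.out.println(loopBound(4, 16));  // prints 0: fewer elements than one 16-lane chunk
        System.out.println(loopBound(37, 8));  // prints 32
        int[] b = { 0, 1, 2, 3, 4, 5 };
        int[] a = new int[b.length];
        scaleAdd(a, b, 2, 1, 4);               // one chunk of 4, then a 2-element tail
        System.out.println(Arrays.toString(a)); // prints [1, 3, 5, 7, 9, 11]
    }
}
```

With the real API, the chunk loop becomes `for (; i < SPECIES.loopBound(b.length); i += SPECIES.length())` with `fromArray`/`intoArray` in place of the inner `j` loop, and the tail loop stays the same.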
