
I am trying to find an explanation for why my matrix multiplication with Numba is so much slower than NumPy's dot function. Although I am using the most basic code for a matrix multiplication function with Numba, I don't think the significantly slower performance is due to the algorithm alone. For simplicity, I consider two k x k square matrices, A and B. My code reads

1     @njit('f8[:,:](f8[:,:], f8[:,:])')
2     def numba_dot(A, B):
3
4         k = A.shape[1]
5         C = np.zeros((k, k))
6
7         for i in range(k):
8             for j in range(k):
9
10                tmp = 0.
11                for l in range(k):
12                    tmp += A[i, l] * B[l, j]
13
14                C[i, j] = tmp
15
16        return C

Running this code repeatedly with two random 1000 x 1000 matrices, it typically takes at least about 1.5 seconds to finish. On the other hand, if I don't update the matrix C, i.e. if I drop line 14 or replace it for the sake of a test with, for example, the following line:

14                C[i, j] = i * j

the code finishes in about 1-5 ms. By comparison, NumPy's dot function takes around 10 ms for this matrix multiplication.

What is the reason for the discrepancy in running times between the above matrix multiplication code and this small variation? And is there a way to store the value of the variable tmp in C[i, j] without degrading the performance of the code so significantly?

Marc
  • Your algorithm is not optimized at all. For a real-world example of how to implement matrix multiplication, see https://gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0. NumPy calls the BLAS function dgemm in this case. Numba will do the same if the inputs are contiguous, e.g. `@njit('f8[:,::1](f8[:,::1], f8[:,::1])')` – max9111 Apr 06 '21 at 07:28
  • Thanks for your reply. For some reason I get similar running times even with contiguous inputs. – Marc Apr 07 '21 at 19:23
  • Just call np.dot in Numba (with contiguous arrays). In both cases NumPy and Numba will do much the same thing (call an external BLAS library). The link was just to show how complicated real-world matrix multiplication is. It is a good learning example, but if you just want to calculate a dot product, this is the way to do it. You can also try it in C (it will still be slower by more than 100 times without some improvements to the algorithm). Also consider that compilers try to optimize away useless parts; the whole inner loop is detected as useless if you write C[i, j] = i * j. – max9111 Apr 07 '21 at 19:46

2 Answers


The native NumPy implementation works with vectorized operations. If your CPU supports these, the processing is much faster. Current microprocessors have on-chip vector units, which pipeline the data transfers and the vector operations.

Your implementation performs k^3 loop iterations; for k = 1000 that is a billion multiply-adds, and a billion of anything takes non-trivial time. Your code also asks for each element-by-element operation in isolation: a billion distinct scalar operations, instead of a much smaller number of wide, pipelined vector operations.
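To illustrate the point (my own sketch, not part of the original answer): handing the whole product to NumPy is a single BLAS call, and even vectorizing just the two inner loops already collapses k^2 scalar updates into one call per row.

import numpy as np

k = 1000
A = np.random.rand(k, k)
B = np.random.rand(k, k)

# One BLAS call: all k**3 multiply-adds are blocked, vectorized and pipelined.
C_blas = A @ B

# Vectorizing only the two inner loops: k small BLAS calls, one row of C at a time.
C_rows = np.empty((k, k))
for i in range(k):
    C_rows[i, :] = A[i, :] @ B

assert np.allclose(C_blas, C_rows)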

Prune
  • Thank you for the answer. I think my example shows that it is not just the number of operations that matters, but also the type of operations. When the code is modified as described and compiled with Numba, the three loops can be executed in a time similar to NumPy's dot function. – Marc Apr 06 '21 at 08:44

With integer arrays, NumPy doesn't make use of BLAS for some reason (source).
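One way to see the effect (my own check, separate from the benchmark below) is to cast the integer operands to float64: BLAS only provides floating-point and complex routines, so the cast lets the product go through dgemm.

import numpy as np

rng = np.random.default_rng(0)
A = rng.integers(1, 100, size=(100, 100))
B = rng.integers(1, 100, size=(100, 100))

C_int = A @ B                                           # generic integer loop, no BLAS
C_float = A.astype(np.float64) @ B.astype(np.float64)   # dispatched to BLAS (dgemm)

# The results agree exactly here, since these small integer sums fit in float64.
assert np.allclose(C_int, C_float)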

import numpy as np
from numba import njit

def matrix_multiplication(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for i in range(m):
    for j in range(n):
      for k in range(p):
        C[i, k] += A[i, j] * B[j, k]
  return C

@njit()
def matrix_multiplication_optimized(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for i in range(m):
    for j in range(n):
      for k in range(p):
        C[i, k] += A[i, j] * B[j, k]
  return C

m = 100
n = 100
p = 100
A = np.random.randint(1, 100, size=(m,n))
B = np.random.randint(1, 100, size=(n, p))

# compile function
matrix_multiplication_optimized(A, B)

%timeit matrix_multiplication(A, B)
%timeit matrix_multiplication_optimized(A, B)
%timeit A @ B
685 ms ± 7.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 ms ± 5.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.49 ms ± 19.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In this case, Numba is even a little bit faster than NumPy. This leads me to think that Numba is generating code that uses vectorization while also being cache friendly (the pure-Python version can't be improved much further). Other loop orders are worse, so I might have used the cache-friendly loop order without realizing it.

@njit()
def matrix_multiplication_optimized2(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for j in range(n):
    for k in range(p):
      for i in range(m):
        C[i, k] += A[i, j] * B[j, k]
  return C

@njit()
def matrix_multiplication_optimized3(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for k in range(p):
    for i in range(m):
      for j in range(n):
        C[i, k] += A[i, j] * B[j, k]
  return C

m = 1000
n = 1000
p = 1000
A = np.random.randn(m, n)
B = np.random.randn(n, p)

# compile function
matrix_multiplication_optimized(A, B)
matrix_multiplication_optimized2(A, B)
matrix_multiplication_optimized3(A, B)


%timeit matrix_multiplication_optimized(A, B)
%timeit matrix_multiplication_optimized2(A, B)
%timeit matrix_multiplication_optimized3(A, B)
%timeit A @ B
1.45 s ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.6 s ± 92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.93 s ± 35.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
30 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In my experience, NumPy is about 50 times faster than Numba with floating-point numbers. This question shows how using BLAS improves performance. The Numba documentation mentions BLAS at the end, but I don't know how to use numpy.linalg.
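For what it's worth, the comments under the question suggest simply calling np.dot inside an njit function with contiguous arrays, so that Numba dispatches to BLAS as well. A minimal sketch (the function name is mine; Numba's np.dot support requires contiguous arrays and a SciPy installation for its BLAS bindings):

@njit
def numba_blas_dot(A, B):
    # With C-contiguous float64 inputs, np.dot is handed to the same
    # external BLAS routine (dgemm) that NumPy itself calls.
    return np.dot(A, B)

# Using the 1000 x 1000 float64 matrices A and B defined above.
assert np.allclose(numba_blas_dot(A, B), A @ B)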

Commenting out the line C[i, j] = tmp makes the temporary variable useless, and the compiler can then remove the whole inner loop:

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):

    k = A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
            tmp = 0.
            for l in range(k):
                tmp += A[i, l] * B[l, j]

            # C[i, j] = tmp

    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):

    k = A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
            # tmp = 0.
            for l in range(k):
                # tmp += A[i, l] * B[l, j]
                pass

            # C[i, j] = tmp

    return C

%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
for k, v in numba_dot.inspect_asm().items():
  print(k, v)
2.59 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.6 ms ± 93.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I can't read the generated code, but the temporary variable was probably removed during optimization since it wasn't used.
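If the raw assembly is unreadable, Numba offers a couple of friendlier views (just a pointer, not part of the measurements here):

# Annotated source with the inferred type of every variable, per compiled signature.
numba_dot.inspect_types()

# LLVM IR, where an optimized-away variable or loop simply no longer appears.
for sig, ir in numba_dot.inspect_llvm().items():
    print(sig)
    print(ir[:2000])  # the IR is long; print only the beginning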

C[i, j] = i * j can be performed relatively quickly. Keep in mind that vectorized operations are being used.

4.18 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Your implementation was slower than mine, so I tried reversing l and j.

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):

    k = A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
            tmp = 0.
            for l in range(k):
                tmp += A[i, l] * B[l, j]

            C[i, j] = tmp

    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):

    k = A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for l in range(k):
            tmp = 0.
            for j in range(k):
                tmp += A[i, l] * B[l, j]
                C[i, j] = tmp

    return C



%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
3.16 s ± 36.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.57 s ± 24.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

When the loops are ordered this way it doesn't really make sense to keep a temporary variable, since j is now the innermost loop (as written, numba_dot2 overwrites C[i, j] with running partial sums rather than the final product). I don't see any issue with updating C[i, j] directly; see the sketch below.
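A minimal sketch of that direct-update variant (the name numba_dot3 is mine): accumulating with += gives the actual matrix product, and the i, l, j order keeps the inner loop walking along contiguous rows of B and C.

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot3(A, B):

    k = A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for l in range(k):
            a_il = A[i, l]  # scalar reused across a whole row of B and C
            for j in range(k):
                C[i, j] += a_il * B[l, j]

    return C

assert np.allclose(numba_dot3(A, B), A @ B)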

BPDev