With integers, numpy doesn't make use of BLAS for some reason (source).
import numpy as np
from numba import njit

def matrix_multiplication(A, B):
    m, n = A.shape
    _, p = B.shape
    C = np.zeros((m, p))
    for i in range(m):
        for j in range(n):
            for k in range(p):
                C[i, k] += A[i, j] * B[j, k]
    return C

@njit()
def matrix_multiplication_optimized(A, B):
    m, n = A.shape
    _, p = B.shape
    C = np.zeros((m, p))
    for i in range(m):
        for j in range(n):
            for k in range(p):
                C[i, k] += A[i, j] * B[j, k]
    return C
m = 100
n = 100
p = 100
A = np.random.randint(1, 100, size=(m, n))
B = np.random.randint(1, 100, size=(n, p))
# compile function
matrix_multiplication_optimized(A, B)
%timeit matrix_multiplication(A, B)
%timeit matrix_multiplication_optimized(A, B)
%timeit A @ B
685 ms ± 7.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 ms ± 5.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.49 ms ± 19.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In this case, numba is even a little bit faster than numpy. This leads me to think that numba is generating vectorized code that is also cache friendly (the Python code itself can't be improved any further). Other loop orders are worse, so I may have used the cache-friendly loop order without realizing it.
@njit()
def matrix_multiplication_optimized2(A, B):
    m, n = A.shape
    _, p = B.shape
    C = np.zeros((m, p))
    for j in range(n):
        for k in range(p):
            for i in range(m):
                C[i, k] += A[i, j] * B[j, k]
    return C

@njit()
def matrix_multiplication_optimized3(A, B):
    m, n = A.shape
    _, p = B.shape
    C = np.zeros((m, p))
    for k in range(p):
        for i in range(m):
            for j in range(n):
                C[i, k] += A[i, j] * B[j, k]
    return C
m = 1000
n = 1000
p = 1000
A = np.random.randn(m, n)
B = np.random.randn(n, p)
# compile functions
matrix_multiplication_optimized(A, B)
matrix_multiplication_optimized2(A, B)
matrix_multiplication_optimized3(A, B)
%timeit matrix_multiplication_optimized(A, B)
%timeit matrix_multiplication_optimized2(A, B)
%timeit matrix_multiplication_optimized3(A, B)
%timeit A @ B
1.45 s ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.6 s ± 92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.93 s ± 35.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
30 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In my experience, numpy is about 50 times faster than numba with floating-point numbers. This question shows how using BLAS improves performance. The numba documentation mentions BLAS at the end, but I don't know how to use numpy.linalg.
Commenting out the line C[i, j] = tmp made the temporary variable useless.
@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):
    k = A.shape[1]
    C = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            tmp = 0.
            for l in range(k):
                tmp += A[i, l] * B[l, j]
            # C[i, j] = tmp
    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):
    k = A.shape[1]
    C = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            # tmp = 0.
            for l in range(k):
                # tmp += A[i, l] * B[l, j]
                pass
            # C[i, j] = tmp
    return C
%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
for k, v in numba_dot.inspect_asm().items():
    print(k, v)
2.59 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.6 ms ± 93.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
I can't read the generated assembly, but the temporary variable was probably removed during optimization since it wasn't used.
C[i, j] = i * j can be performed relatively quickly. Keep in mind that vectorized operations are being used.
4.18 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Your implementation was slower than mine, so I tried reversing l and j.
@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):
    k = A.shape[1]
    C = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            tmp = 0.
            for l in range(k):
                tmp += A[i, l] * B[l, j]
            C[i, j] = tmp
    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):
    k = A.shape[1]
    C = np.zeros((k, k))
    for i in range(k):
        for l in range(k):
            tmp = 0.
            for j in range(k):
                tmp += A[i, l] * B[l, j]
            C[i, j] = tmp
    return C
%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
3.16 s ± 36.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.57 s ± 24.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
When reversing the loops like that, it doesn't really make sense to keep a temporary variable, since j is now the innermost loop. I don't see any issue with updating C[i, j] directly.