It is quite weird that numba can be so much slower.
It's not too weird. When you call NumPy functions inside a numba function, you call the numba version of these functions. These can be faster, slower, or just as fast as the NumPy versions. You might be lucky or you might be unlucky (you were unlucky!). But even inside the numba function you still create lots of temporaries, because you use the NumPy functions (one temporary array for the dot result, one each for the squares and sums, one for the dot plus the first sum), so you don't take advantage of the possibilities of numba.
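To make the temporaries concrete, here is a hypothetical reconstruction of the kind of vectorized formula the question uses (the exact original isn't shown here, so treat this as an illustrative sketch): every operation in the return expression allocates its own full-size intermediate array, whether it runs through NumPy or through numba's reimplementation of these functions.

```python
import numpy as np

# Hypothetical sketch of the vectorized squared-euclidean-distance formula.
# Each step below allocates a temporary array of its own:
def euclidean_distance_square(x1, x2):
    dot = np.dot(x1, x2.T)                  # temporary for the dot product
    sq1 = np.sum(x1 ** 2, axis=1)[:, None]  # temporaries for square and sum
    sq2 = np.sum(x2 ** 2, axis=1)           # more temporaries
    return -2 * dot + sq1 + sq2             # temporaries for each * and +
```

Wrapping exactly this in `@nb.njit` keeps all of those allocations, which is why it gains nothing over plain NumPy.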
Am I using it wrong?
Essentially: Yes.
I really need to speed this up
Okay, I'll give it a try.
Let's start by unrolling the sum-of-squares along axis 1 calls:
import numba as nb
import numpy as np

@nb.njit
def sum_squares_2d_array_along_axis1(arr):
    res = np.empty(arr.shape[0], dtype=arr.dtype)
    for o_idx in range(arr.shape[0]):
        sum_ = 0.
        for i_idx in range(arr.shape[1]):
            sum_ += arr[o_idx, i_idx] * arr[o_idx, i_idx]
        res[o_idx] = sum_
    return res

@nb.njit
def euclidean_distance_square_numba_v1(x1, x2):
    return (-2 * np.dot(x1, x2.T)
            + np.expand_dims(sum_squares_2d_array_along_axis1(x1), axis=1)
            + sum_squares_2d_array_along_axis1(x2))
On my computer that's already 2 times faster than the NumPy code and almost 10 times faster than your original Numba code.
Speaking from experience, getting it 2x faster than NumPy is generally the limit (at least if the NumPy version isn't needlessly complicated or inefficient). However, you can squeeze out a bit more by unrolling everything:
import numba as nb
import numpy as np

@nb.njit
def euclidean_distance_square_numba_v2(x1, x2):
    f1 = 0.
    for i_idx in range(x1.shape[1]):
        f1 += x1[0, i_idx] * x1[0, i_idx]

    res = np.empty(x2.shape[0], dtype=x2.dtype)
    for o_idx in range(x2.shape[0]):
        val = 0.
        for i_idx in range(x2.shape[1]):
            val_from_x2 = x2[o_idx, i_idx]
            val += (-2) * x1[0, i_idx] * val_from_x2 + val_from_x2 * val_from_x2
        val += f1
        res[o_idx] = val
    return res
But that only gives a ~10-20% improvement over the previous approach.
At that point you might realize that you can simplify the code (even though it probably won't speed it up):
import numba as nb
import numpy as np

@nb.njit
def euclidean_distance_square_numba_v3(x1, x2):
    res = np.empty(x2.shape[0], dtype=x2.dtype)
    for o_idx in range(x2.shape[0]):
        val = 0.
        for i_idx in range(x2.shape[1]):
            tmp = x1[0, i_idx] - x2[o_idx, i_idx]
            val += tmp * tmp
        res[o_idx] = val
    return res
Yeah, that looks pretty straightforward and it's not really slower.
However, in all the excitement I forgot to mention the obvious solution: scipy.spatial.distance.cdist, which has a sqeuclidean (squared euclidean distance) option:
from scipy.spatial import distance
distance.cdist(x1, x2, metric='sqeuclidean')
It's not really faster than numba but it's available without having to write your own function...
Tests
Test for correctness and do the warmups:
x1 = np.array([[1.,2,3]])
x2 = np.array([[1.,2,3], [2,3,4], [3,4,5], [4,5,6], [5,6,7]])
res1 = euclidean_distance_square(x1, x2)
res2 = euclidean_distance_square_numba_original(x1, x2)
res3 = euclidean_distance_square_numba_v1(x1, x2)
res4 = euclidean_distance_square_numba_v2(x1, x2)
res5 = euclidean_distance_square_numba_v3(x1, x2)
np.testing.assert_array_equal(res1, res2)
np.testing.assert_array_equal(res1, res3)
np.testing.assert_array_equal(res1[0], res4)
np.testing.assert_array_equal(res1[0], res5)
np.testing.assert_almost_equal(res1, distance.cdist(x1, x2, metric='sqeuclidean'))
Timings:
x1 = np.random.random((1, 512))
x2 = np.random.random((1000000, 512))
%timeit euclidean_distance_square(x1, x2)
# 2.09 s ± 54.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_original(x1, x2)
# 10.9 s ± 158 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v1(x1, x2)
# 907 ms ± 7.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v2(x1, x2)
# 715 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v3(x1, x2)
# 731 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit distance.cdist(x1, x2, metric='sqeuclidean')
# 706 ms ± 4.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Note: If you have arrays of integers you might want to change the hard-coded 0.0 in the numba functions to 0.