I modified the fastest code from "Why this numba code is 6x slower than numpy code?" so that it can handle x1 of shape (n, m):
import numba as nb
import numpy as np

@nb.njit(fastmath=True, parallel=True)
def euclidean_distance_square_numba_v5(x1, x2):
    # res[a, o] = squared Euclidean distance between x1[a] and x2[o]
    res = np.empty((x1.shape[0], x2.shape[0]), dtype=x2.dtype)
    for a_idx in nb.prange(x1.shape[0]):  # rows of x1 are split across threads
        for o_idx in range(x2.shape[0]):
            val = 0.
            for i_idx in range(x2.shape[1]):
                tmp = x1[a_idx, i_idx] - x2[o_idx, i_idx]
                val += tmp * tmp
            res[a_idx, o_idx] = val
    return res
However, it is still slower than the fastest NumPy version:
def euclidean_distance_square_einsum(x1, x2):
    return np.einsum('ij,ij->i', x1, x1)[:, np.newaxis] + np.einsum('ij,ij->i', x2, x2) - 2*np.dot(x1, x2.T)
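The einsum version relies on the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, which replaces the explicit pairwise loop with a single matrix product. Below is a minimal sketch that checks this identity against the explicit difference on small random inputs. The array sizes, the seed, and the helper names `pairwise_sq_dist_direct` and `pairwise_sq_dist_expanded` are illustrative assumptions, not part of the code above.

```python
import numpy as np

def pairwise_sq_dist_direct(x1, x2):
    # Builds the explicit (n, k, m) difference tensor; only suitable for small inputs.
    diff = x1[:, np.newaxis, :] - x2[np.newaxis, :, :]
    return np.sum(diff * diff, axis=-1)

def pairwise_sq_dist_expanded(x1, x2):
    # ||a||^2 + ||b||^2 - 2 a.b, the same formula as the einsum version.
    sq1 = np.einsum('ij,ij->i', x1, x1)[:, np.newaxis]
    sq2 = np.einsum('ij,ij->i', x2, x2)
    return sq1 + sq2 - 2 * np.dot(x1, x2.T)

# Small float64 test arrays (sizes and seed chosen arbitrarily for illustration).
rng = np.random.default_rng(0)
x1 = rng.standard_normal((5, 8))
x2 = rng.standard_normal((3, 8))

direct = pairwise_sq_dist_direct(x1, x2)
expanded = pairwise_sq_dist_expanded(x1, x2)
print(np.allclose(direct, expanded))
```

Note that the expanded form subtracts nearly equal quantities when two rows are close, so in float32 it can differ slightly from the loop-based result.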
With the following inputs:
a = np.zeros((1000000,512), dtype=np.float32)
b = np.zeros((100, 512), dtype=np.float32)
The timings I got are 2.4723422527313232 for the Numba code and 0.8260958194732666 for the NumPy code.
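For anyone reproducing this, here is a minimal sketch of one way to take such timings. It is not necessarily how the numbers above were measured. The helper `time_call`, the `repeats` parameter, and the smaller array size are assumptions made for illustration. The warm-up call matters for the Numba function, because an `@nb.njit` function without an explicit signature is compiled on its first call, and that compile time would otherwise be counted.

```python
import time
import numpy as np

def euclidean_distance_square_einsum(x1, x2):
    return np.einsum('ij,ij->i', x1, x1)[:, np.newaxis] + np.einsum('ij,ij->i', x2, x2) - 2*np.dot(x1, x2.T)

def time_call(func, *args, repeats=3):
    # Hypothetical helper: one untimed warm-up call, then the best of `repeats` timed runs.
    func(*args)  # warm-up; for a lazily compiled njit function this triggers compilation
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best

# Smaller than the arrays in the question so the sketch runs quickly.
a = np.zeros((1000, 512), dtype=np.float32)
b = np.zeros((100, 512), dtype=np.float32)

elapsed = time_call(euclidean_distance_square_einsum, a, b)
print(f"einsum version: {elapsed:.6f} s")
```

The same helper can be called with `euclidean_distance_square_numba_v5` in place of the einsum function to compare the two on equal terms.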