0

I'd like to speed up the cdist between two numpy.ndarray using numba as follows:

import numpy as np
from numba import njit, jit
from scipy.spatial.distance import cdist
import time

@njit
def dist_scipy(a, b):
    d = cdist(a, b, 'euclidean')
    d = np.transpose(d)
    sorted_d = np.sort(d)
    sorted_ind = np.argsort(d)
    return sorted_d, sorted_ind

def get_a_b(r=10**4,c=10** 1):
    a = np.random.uniform(-1, 1, (r, c)).astype('f')
    b = np.random.uniform(-1, 1, (r, c)).astype('f')
    return a,b

if __name__ == "__main__":
    a, b = get_a_b()
    st_t = time.time()
    dist_scipy(a,b)
    print('it took {} s'.format(time.time()-st_t))

In python2, after $ pip install numba-scipy, I get the following error:

Traceback (most recent call last):
  File "stackoverflow_Q.py", line 31, in <module>
    dist_scipy(a,b)
  File "/usr/local/lib/python2.7/dist-packages/numba/dispatcher.py", line 420, in _compile_for_args
    raise e
  File "/usr/local/lib/python2.7/dist-packages/numba_scipy/special/overloads.py", line 12
    f = signatures.name_and_types_to_pointer[(name, *signature)]
                                                    ^
SyntaxError: invalid syntax

And in python3, after $ conda install -c conda-forge numba numba-scipy, I get the following error:

Traceback (most recent call last):
  File "numba_scipy_test.py", line 31, in <module>
    dist_scipy(a,b)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 420, in _compile_for_args
    raise e
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 353, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 768, in compile
    cres = self._compiler.compile(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 77, in compile
    status, retval = self._compile_cached(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 91, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/dispatcher.py", line 109, in _compile_core
    pipeline_class=self.pipeline_class)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 550, in compile_extra
    args, return_type, flags, locals)
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/compiler.py", line 281, in __init__
    targetctx.refresh()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/base.py", line 281, in refresh
    self.load_additional_registries()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/targets/cpu.py", line 80, in load_additional_registries
    numba.entrypoints.init_all()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/entrypoints.py", line 24, in init_all
    func()
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/__init__.py", line 12, in _init_extension
    from . import special
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/__init__.py", line 1, in <module>
    from . import overloads as _overloads
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/overloads.py", line 4, in <module>
    from . import signatures
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba_scipy/special/signatures.py", line 376, in <module>
    ('pdtr', numba.types.float64, numba.types.float64): ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)(get_cython_function_address('scipy.special.cython_special', '__pyx_fuse_0pdtr')),
  File "/home/alijani/.conda/envs/py3_gpu/lib/python3.7/site-packages/numba/extending.py", line 406, in get_cython_function_address
    return _import_cython_function(module_name, function_name)
ValueError: No function '__pyx_fuse_0pdtr' found in __pyx_capi__ of 'scipy.special.cython_special'
Farid Alijani
  • 839
  • 1
  • 7
  • 25
  • It isn't yet supported. (Actually there is only most of scipy.special supported). A possible implementation: https://stackoverflow.com/a/60854278/4045774 – max9111 Apr 03 '20 at 12:21
  • what doe ur implementation do? I couldn't follow ur code snippet? are you using `scipy.spatial.distance` ? that's actually where numba struggles! – Farid Alijani Apr 03 '20 at 13:24
  • As said, `scipy.spatial.distance.cdist` isn't supported by Numba, but it is very easy to implement it. A few minor changes to `dist_arr_1(A)` and you have a replacement for that unsupported function. – max9111 Apr 03 '20 at 13:35

0 Answers0