I recently wrote a script that converts BGR arrays of floats in [0, 1] to HSL and back, and I posted it on Code Review. It currently has one answer, but that answer doesn't improve performance.
I benchmarked my code against `cv2.cvtColor` and found it to be inefficient, so I want to compile it with Numba to make it run faster.
I tried wrapping every function with `@nb.njit(cache=True, fastmath=True)`, but this doesn't work.
So I tested each NumPy construct and function I use individually, and found two functions that don't work with Numba. I need the maximum channel of each pixel (`np.max(img, axis=-1)`) and the minimum channel of each pixel (`np.min(img, axis=-1)`), and Numba doesn't accept the `axis` argument for either of them.
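Because a BGR image has exactly three channels, one way to avoid the unsupported `axis` keyword is to compare the channel slices pairwise with the element-wise ufuncs `np.maximum` and `np.minimum`, which Numba supports on arrays. The following is a minimal sketch under that assumption (input shape `(height, width, 3)`). The name `max_min_per_pixel` and the import fallback are illustrative and do not come from the original script. The `cache=True, fastmath=True` options are left out to keep it minimal.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Hypothetical fallback: run as plain Python if Numba is not installed.
    def njit(func):
        return func


@njit
def max_min_per_pixel(img):
    # Assumes img has shape (height, width, 3), e.g. a BGR image.
    b = img[:, :, 0]
    g = img[:, :, 1]
    r = img[:, :, 2]
    # Element-wise comparisons replace np.max/np.min with axis=-1.
    cmax = np.maximum(np.maximum(b, g), r)
    cmin = np.minimum(np.minimum(b, g), r)
    return cmax, cmin


img = np.random.random((3, 4, 3))
cmax, cmin = max_min_per_pixel(img)
```

This only works for a fixed, small number of channels, because each channel is written out explicitly.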
I tried searching Google for this, but the only remotely relevant thing I found is this, which only implements `np.any` and `np.all`, and only for two-dimensional arrays, whereas my arrays are three-dimensional.
I could write a solution based on `for` loops, but I won't, because it is bound to be inefficient and would defeat the purpose of using NumPy and Numba in the first place.
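For comparison, here is a sketch of the explicit-loop approach. In nopython mode Numba compiles such loops to machine code, so it is worth benchmarking them rather than assuming they are slow. This version computes both extremes in a single pass and works for any number of channels (at least one). The name `max_min_last_axis` and the import fallback are hypothetical.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Hypothetical fallback: run as plain Python if Numba is not installed.
    def njit(func):
        return func


@njit
def max_min_last_axis(arr):
    # Reduces a (height, width, channels) array over its last axis.
    # Assumes channels >= 1.
    h, w, c = arr.shape
    cmax = np.empty((h, w), dtype=arr.dtype)
    cmin = np.empty((h, w), dtype=arr.dtype)
    for i in range(h):
        for j in range(w):
            hi = arr[i, j, 0]
            lo = arr[i, j, 0]
            for k in range(1, c):
                v = arr[i, j, k]
                if v > hi:
                    hi = v
                if v < lo:
                    lo = v
            cmax[i, j] = hi
            cmin[i, j] = lo
    return cmax, cmin


img = np.random.random((3, 4, 3))
cmax, cmin = max_min_last_axis(img)
```

Another option to try is `@nb.njit(parallel=True)` with `nb.prange` on the outer loop.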
Minimal reproducible example:

```python
import numba as nb
import numpy as np


@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
    return np.max(arr, axis=-1)


@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
    return np.min(arr, axis=-1)


img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)
```
Exception:
```text
In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465 f"by the following argument(s):\n{args_str}\n")
    466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470 # Something unsupported is present in the user code, add help info
    471 error_rewrite(e, 'unsupported_error')

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407 raise e
    408 else:
--> 409 raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:

 >>> amax(array(float64, 3d, C), axis=Literal[int](-1))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
    With argument(s): '(array(float64, 3d, C), axis=int64)':
   Rejected as the implementation raised a specific error:
     TypingError: got an unexpected keyword argument 'axis'
  raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)

File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
    return np.max(arr, axis=-1)
    ^
```
How can I fix this?