I encountered the same issue and decided to do some experimenting. I discovered that, if the axis can be an integer literal (i.e. it is known ahead of time and doesn't need to be retrieved from a variable), there is a Numba-compatible alternative. That said, the JIT-compiled replacement was slower than plain np.all in my test function, so benchmark your own function to confirm there is a net improvement before using this.
As others have pointed out, Numba doesn't support the axis argument of several NumPy functions, including np.all. The first potential solution I thought of is np.amin (aka np.ndarray.min): np.all(a, axis=axis) is identical to np.amin(a, axis=axis) for boolean arrays, and identical to np.amin(a, axis=axis).astype('bool') for numerical arrays, provided they contain no negative values (the minimum of [-1, 0] is -1, which is truthy even though a zero is present). Unfortunately, np.amin is also in the list of functions for which the axis argument is not supported. However, np.argmin does support the axis argument, and so does np.take_along_axis.
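The equivalence between np.all and np.amin can be checked quickly in plain NumPy (no Numba involved). This is only a small sketch: the arrays and variable names are made up for illustration, and the last case shows why the numeric shortcut assumes non-negative values.

```python
import numpy as np

# Boolean input: the minimum along an axis is True only if every element is True.
b = np.array([[True, True], [True, False]])
bool_match = np.array_equal(np.all(b, axis=1), np.amin(b, axis=1))

# Non-negative numeric input: a zero minimum means at least one element is zero.
n = np.array([[1, 2], [0, 3]])
numeric_match = np.array_equal(np.all(n, axis=1),
                               np.amin(n, axis=1).astype('bool'))

# Negative values break the shortcut: the minimum of [-1, 0] is -1 (truthy),
# although np.all correctly reports False because of the zero.
neg = np.array([[-1, 0]])
neg_all = np.all(neg, axis=1)
neg_amin = np.amin(neg, axis=1).astype('bool')

# Comparing against zero first turns the problem into the boolean case.
neg_fixed = np.amin(neg != 0, axis=1)
```

If negative values are possible, converting to a boolean array first (a != 0) and then using the boolean variant avoids the mismatch.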
Therefore, np.all(a, axis=axis) can be replaced with the following.
For numeric arrays:
np.take_along_axis(a, np.expand_dims(np.argmin(a, axis=axis), axis), axis)[(:, ){axis} 0].astype('bool')
For boolean arrays:
np.take_along_axis(a, np.expand_dims(np.argmin(a.astype('int64'), axis=axis), axis), axis)[(:, ){axis} 0]
The placeholder (:, ){axis} should be replaced with axis repetitions of ":, " so that the correct axis is eliminated. For example, if a is a boolean array and axis is 2, you would use
np.take_along_axis(a, np.expand_dims(np.argmin(a.astype('int64'), axis=2), 2), 2)[:, :, 0]
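To see what the indexing placeholder expands to, here is a NumPy-only sketch that builds the same index programmatically. The helper name all_along_axis is hypothetical, and this version is meant purely as an illustration outside of Numba: inside an njit function the slice still has to be written out literally, as described above. It also assumes numeric input has no negative values.

```python
import numpy as np

def all_along_axis(a, axis):
    """NumPy-only illustration of the replacement expression (hypothetical helper)."""
    # Boolean input is cast to int64 before argmin, as in the boolean variant.
    key = a.astype('int64') if a.dtype == np.bool_ else a
    idx = np.expand_dims(np.argmin(key, axis=axis), axis)
    # Pick the minimum element along `axis`; the result keeps a length-1 axis.
    picked = np.take_along_axis(a, idx, axis)
    # Same as writing `axis` copies of ":, " followed by 0, e.g. [:, :, 0] for axis=2.
    index = (slice(None),) * axis + (0,)
    return picked[index].astype('bool')

# Example: a 3-D boolean array reduced along its last axis.
b = np.arange(24).reshape((2, 3, 4)) % 5 != 0
result = all_along_axis(b, 2)
expected = np.all(b, axis=2)
```

Here result and expected should agree element for element, and the same holds for the other axes.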
Benchmarks
All I can say about this is: if you really need a numpy.all alternative within a function that would benefit greatly from JIT compilation overall, this solution is suitable. If you're just looking to speed up all by itself, you won't have much luck.
test.py
import numba
import numpy as np


# @numba.njit  # raises a TypingError
def using_all():
    n = np.arange(10000).reshape((-1, 5))  # numeric array
    b = n < 4888  # boolean array
    return (np.all(n, axis=1),
            np.all(b, axis=1))


# @numba.njit  # raises a TypingError
def using_amin():
    n = np.arange(10000).reshape((-1, 5))  # numeric array
    b = n < 4888  # boolean array
    return (np.amin(n, axis=1).astype('bool'),
            np.amin(b, axis=1))


@numba.njit  # doesn't raise a TypingError
def using_take_along_axis():
    n = np.arange(10000).reshape((-1, 5))  # numeric array
    b = n < 4888  # boolean array
    return (np.take_along_axis(n, np.expand_dims(np.argmin(n, axis=1), 1), 1)[:, 0].astype('bool'),
            np.take_along_axis(b, np.expand_dims(np.argmin(b.astype('int64'), axis=1), 1), 1)[:, 0])


if __name__ == '__main__':
    # Check that all three approaches return the same results.
    a = using_all()
    m = using_amin()
    assert np.all(a[0] == m[0])
    assert np.all(a[1] == m[1])
    t = using_take_along_axis()
    assert np.all(a[0] == t[0])
    assert np.all(a[1] == t[1])
PS C:\> python -m timeit -n 10000 -s 'from test import using_all; using_all()' 'using_all()'
10000 loops, best of 5: 32.9 usec per loop
PS C:\> python -m timeit -n 10000 -s 'from test import using_amin; using_amin()' 'using_amin()'
10000 loops, best of 5: 43.5 usec per loop
PS C:\> python -m timeit -n 10000 -s 'from test import using_take_along_axis; using_take_along_axis()' 'using_take_along_axis()'
10000 loops, best of 5: 55.4 usec per loop