(EDIT: OK, I have now had a bit more time to figure out what is going on.)
There are two issues here:
- the computational complexity depends on the sizes of both inputs, so it is not captured well by a 1D benchmark plot
- the actual timings are dominated by the variation in the inputs
The problem can be separated into two parts:
- looping through the rows
- performing the subset check, which is basically a nested-loop quadratic operation (in the worst-case scenario)
We know that, for sufficiently large inputs, looping through the rows is faster in NumPy and slower in pure Python.
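As a rough illustration of the row-looping cost alone (the sizes here are made-up values), compare a Python-level loop over the rows with the same reduction performed inside NumPy:

import numpy as np

arr = np.random.randint(0, 100, (100_000, 10))

# Python-level loop: one interpreter iteration per row
total_py = sum(int(row[0]) for row in arr)

# the same reduction done inside NumPy: the loop over the rows runs in C
total_np = int(arr[:, 0].sum())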
For reference, let's consider these two approaches:
# pure Python approach
def all_in_by_row_flt(arr, elems=ELEMS):
    return sum(1 for row in arr if all(e in row for e in elems))
# NumPy approach (based on @Mstaino's answer)
def all_in_by_row_np(arr, elems=ELEMS):
    def _aaa_helper(row, e=elems):
        return np.isin(e, row)
    return np.sum(np.all(np.apply_along_axis(_aaa_helper, 1, arr), 1))
Then, considering the subset-check operation: if the input is such that the check terminates within fewer iterations, pure Python looping gets faster than NumPy. Conversely, if a sufficiently large number of iterations is required, then NumPy can actually be faster.
On top of this comes the looping through the rows. Because the subset-check operation is quadratic AND the two implementations have different constant coefficients, there are situations where, despite the row looping being faster in NumPy (because the number of rows is sufficiently large), the overall operation is faster in pure Python.
This was the situation I was running into in the earlier benchmarks; it corresponds to the case where the subset check is always (or almost always) False and fails within a few iterations.
As soon as the subset check starts requiring more iterations, the pure Python approach begins to lag behind, and in the situation where the subset check is actually True for most (if not all) of the rows, the NumPy approach is actually faster.
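For example, shifting the value range of the input rows away from (or onto) the value range of elems changes how early the check fails; a minimal sketch (the ranges are made-up values for illustration):

import numpy as np

elems = np.random.randint(0, 1000, 1000)

# values disjoint from elems: every subset check fails at the first element
arr_fast_fail = np.random.randint(10_000, 20_000, (100, 4000))

# values drawn from the same range as elems: the checks rarely fail early
arr_slow_fail = np.random.randint(0, 1000, (100, 4000))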
Another key difference between the NumPy and the pure Python approach is that pure Python uses lazy evaluation, while NumPy does not and actually requires the creation of potentially large intermediate objects that slow down the computation.
On top of this, NumPy iterates over the rows twice (once in np.sum() and once in np.apply_along_axis()), while the pure Python approach does so only once.
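To make the short-circuiting difference concrete, here is a minimal sketch (row and elems are made-up values; -1 stands for an element that is missing from the row):

import numpy as np

row = np.arange(1000)
elems = np.array([-1, 2, 3])

# pure Python: stops as soon as the first missing element (-1) is found
all(e in row for e in elems)

# NumPy: np.isin always builds the full boolean array before np.all() sees it
np.all(np.isin(elems, row))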
Other approaches using set().issubset(), like e.g. the one from @GZ0's answer:

def all_in_by_row_set(arr, elems=ELEMS):
    elems = set(elems)
    return sum(map(elems.issubset, arr))

have different timings than the explicit nested loop when it comes to subset checking, but they still suffer from the slower outer looping.
So, what's next?
The answer is to use Cython or Numba.
The idea is to get NumPy-like (read: C) speed all the time (and not only for sufficiently large inputs), lazy evaluation, and a minimal number of passes through the rows.
An example of a Cython approach (as implemented in IPython, using the %load_ext Cython magic) is:
%%cython --cplus -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True

cdef long all_in_by_row_c(long[:, :] arr, long[:] elems) nogil:
    # explicit C declarations: no Python objects are allowed inside a nogil function
    cdef long i, j, k
    cdef long result = 0
    cdef long I = arr.shape[0]
    cdef long J = arr.shape[1]
    cdef long K = elems.shape[0]
    cdef bint is_subset, is_contained
    for i in range(I):
        is_subset = True
        for k in range(K):
            is_contained = False
            for j in range(J):
                if elems[k] == arr[i, j]:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result

def all_in_by_row_cy(long[:, :] arr, long[:] elems):
    return all_in_by_row_c(arr, elems)
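The long[:, :] memoryview signature requires C-contiguous inputs of a matching integer dtype (np.int64 on most 64-bit Unix platforms; this depends on the size of the platform's long). A minimal usage sketch with made-up inputs:

import numpy as np

arr = np.random.randint(0, 2000, (100, 1000)).astype(np.int64)
elems = np.random.randint(0, 2000, 1000).astype(np.int64)
print(all_in_by_row_cy(arr, elems))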
A similar Numba version reads:
import numba as nb

@nb.jit(nopython=True, nogil=True)
def all_in_by_row_jit(arr, elems=ELEMS):
    result = 0
    n_rows, n_cols = arr.shape
    for i in range(n_rows):
        is_subset = True
        for e in elems:
            is_contained = False
            for r in arr[i, :]:
                if e == r:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result
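Note that with Numba the first call triggers the JIT compilation, so a warm-up call should be issued before timing. The timings below are in IPython %timeit format; a plain-Python sketch of a comparable harness (the input generation parameters are assumptions, not the exact ones used) could read:

import timeit
import numpy as np

def run_benchmark(n_rows=100, n_cols=1000, n_elems=1000, max_val=2000, seed=0):
    np.random.seed(seed)
    arr = np.random.randint(0, max_val, (n_rows, n_cols)).astype(np.int64)
    elems = np.random.randint(0, max_val, n_elems).astype(np.int64)
    all_in_by_row_jit(arr, elems)  # warm-up call: excludes JIT compilation time
    print(f'arr.shape={arr.shape} elems.shape={elems.shape} '
          f'result={all_in_by_row_flt(arr, elems)}')
    for func in (all_in_by_row_cy, all_in_by_row_jit, all_in_by_row_flt,
                 all_in_by_row_set, all_in_by_row_np):
        best = min(timeit.repeat(lambda: func(arr, elems), number=1, repeat=3))
        print(f'Func: {func.__name__}  {best:.6f} s (best of 3, 1 loop each)')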
Now, time-wise we get to the following (for a relatively small number of rows):
arr.shape=(100, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 120 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 129 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 2.44 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_set 9.98 ms ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np 13.7 ms ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
arr.shape=(100, 2000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 1.45 ms ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_jit 1.52 ms ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_flt 30.1 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_set 19.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np 18 ms ± 28.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
arr.shape=(100, 3000) elems.shape=(1000,) result=37
Func: all_in_by_row_cy 10.4 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 10.9 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 226 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30.5 ms ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 21.9 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
arr.shape=(100, 4000) elems.shape=(1000,) result=86
Func: all_in_by_row_cy 16.8 ms ± 32.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 17.7 ms ± 42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 385 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 39.5 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 25.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Note that the slow-down of the last block cannot be explained by the increased input size in the second dimension alone.
Actually, if the short-circuit rate is increased (e.g. by changing the value range of the random arrays), for the last block (same input sizes) one gets:
arr.shape=(100, 4000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 152 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 173 µs ± 4.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 556 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_set 39.7 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 31.5 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Note that the set()-based method is more or less independent of the short-circuit rate (because of the hash-based implementation, which has ~O(1) complexity for checking the presence of a single element), but this comes at the expense of the hashing pre-computation, and these results indicate that it may not be faster than the direct nested-loop approach.
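The reason is that elems.issubset(row) converts the non-set argument row into a set (hashing every element of the row up front, regardless of any possible early exit) before performing the ~O(1) membership checks; roughly, it behaves like this sketch:

def issubset_equivalent(elems, row):
    # the whole row is hashed up front, whether or not an early exit is possible
    row_set = set(row)
    return all(e in row_set for e in elems)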
Finally, for larger row counts:
arr.shape=(100000, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 141 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_jit 150 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_flt 2.6 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 10.1 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 13.7 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 2000) elems.shape=(1000,) result=34
Func: all_in_by_row_cy 1.2 s ± 753 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 1.27 s ± 7.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 24.1 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 19.5 s ± 270 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 18 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 3000) elems.shape=(1000,) result=33859
Func: all_in_by_row_cy 9.79 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 10.3 s ± 5.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 3min 30s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30 s ± 57.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 21.9 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 4000) elems.shape=(1000,) result=86376
Func: all_in_by_row_cy 17 s ± 30.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 17.9 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 6min 29s ± 293 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 38.9 s ± 33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 25.7 s ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
As a closing remark, note that the Cython/Numba code may be further optimized algorithmically.
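For example (a sketch, not benchmarked here), sorting each row once and replacing the linear scan with a binary search would reduce the per-element check from O(n_cols) to O(log(n_cols)), at the price of an O(n_cols * log(n_cols)) sort per row and of a weaker early exit:

import numba as nb
import numpy as np

@nb.jit(nopython=True, nogil=True)
def all_in_by_row_sorted(arr, elems):
    result = 0
    n_rows = arr.shape[0]
    for i in range(n_rows):
        row = np.sort(arr[i, :])  # one O(n_cols * log(n_cols)) sort per row
        is_subset = True
        for e in elems:
            j = np.searchsorted(row, e)  # binary search: O(log(n_cols))
            if j >= row.size or row[j] != e:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result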