The following code performs a recursive binary search over a one-dimensional NumPy array of boolean values, looking for changes of (true/false) state and noting the array indices at which these changes happen.
For every change, the algorithm will output something equivalent to:
- the index before the change happens
- the index when the change happens
- the type of change happening (before-becomes-true, when-becomes-true, and so on)
It does so by detecting large spans of consecutive identical values (checking them with the very fast implementations of `numpy.all()` / `numpy.any()`) and skipping them, in a recursive binary-search manner, until the smallest chunks contain only 2 or 3 non-uniform elements.
```python
from numba import jit
import numpy as np

@jit(nopython=True)
def binary_search(arr, start, end):
    void = [0, 0, 0, 0]
    if arr[start:end].all() or not arr[start:end].any():
        return void + void
    elif arr[start:end].any():
        if end - start == 2:
            one = arr[start]
            two = arr[start+1]
            c = [one, two]
            if c == [False, True]:
                return [start, -2, start+1, -1] + void
            elif c == [True, False]:
                return [start, 1, start+1, 2] + void
        elif end - start == 3:
            one = arr[start]
            two = arr[start+1]
            three = arr[start+2]
            c = [one, two, three]
            if c == [True, True, False]:
                return [start+1, 1, start+2, 2] + void
            elif c == [True, False, False]:
                return [start, 1, start+1, 2] + void
            elif c == [False, True, True]:
                return [start, -2, start+1, -1] + void
            elif c == [False, False, True]:
                return [start+1, -2, start+2, -1] + void
            elif c == [False, True, False]:
                return [start, -2, start+1, -1, start+1, 1, start+2, 2]
            elif c == [True, False, True]:
                return [start, 1, start+1, 2, start+1, -2, start+2, -1]
        else:
            mid = start + (end - start) // 2
            left = binary_search(arr, start, mid)
            right = binary_search(arr, mid-1, end)
            return left + right
```
I've been using this code on the following generated sample data (made of large spans of consecutive true/false values, which is representative of the real data):
```python
positions = np.zeros(shape=300000, dtype=bool)  # np.bool is removed in NumPy >= 1.24
positions[0:748] = True
positions[1305:3281] = True
positions[10938:12389] = True
positions[18392:23819] = True
positions[35884:36728] = True
positions[44847:45238] = True
# ... and so on until the end
```
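As a sanity check on what counts as a "change" here, the expected transition indices in data of this shape can be listed with a vectorized pair comparison (a NumPy-only check, independent of the function above; the smaller array below is just for illustration):

```python
import numpy as np

# Small-scale version of the sample data above (same structure, fewer rows).
positions = np.zeros(shape=100, dtype=bool)
positions[0:7] = True
positions[13:32] = True

# Indices i such that positions[i] != positions[i+1], i.e. the "before the
# change" indices; the change itself happens at i+1.
change_before = np.flatnonzero(positions[:-1] != positions[1:])
print(change_before)  # [ 6 12 31]
```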
On this sample data (300k rows), the pure-Python version of `binary_search()` (without `@jit`) consistently runs in about 11 ms. The Numba version (with `@jit`, after compilation) consistently runs in about 5 ms.
The difference between the Numba and non-Numba versions is not even an order of magnitude here. I'm used to speedups of up to 100x, bringing an 11 ms pure-Python runtime down to the microsecond range. I find it very surprising that Numba performs no better in this case.
Is there a reason why my code isn't particularly accelerated by Numba? Is it because of improper use of recursion? Is it because I'm using raw Python `list`s to return values between recursive calls? (And in that case, is there a way to do otherwise?)
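Regarding that last point, one alternative I can think of is to write results into a preallocated NumPy buffer and thread a write cursor through the recursion instead of concatenating Python lists. This is only a sketch, not my actual code: `find_changes` is a hypothetical name, the size-3 base case is replaced by a one-element overlap between halves, and in real use the function would carry `@jit(nopython=True)`:

```python
import numpy as np

def find_changes(arr, start, end, out, pos):
    """Record (index, code) pairs for state changes in arr[start:end] into out.

    Returns the updated write position. Codes follow the question's scheme:
    1/2 for true-becomes-false (before/when), -2/-1 for false-becomes-true.
    """
    chunk = arr[start:end]
    if chunk.all() or not chunk.any():
        return pos  # uniform span: no change inside, skip it wholesale
    if end - start == 2:
        # Non-uniform pair: exactly one transition between start and start+1.
        if arr[start] and not arr[start + 1]:
            out[pos], out[pos + 1] = start, 1
            out[pos + 2], out[pos + 3] = start + 1, 2
        else:
            out[pos], out[pos + 1] = start, -2
            out[pos + 2], out[pos + 3] = start + 1, -1
        return pos + 4
    # Split with a one-element overlap so the boundary pair (mid, mid+1)
    # belongs to exactly one half and no transition is missed or duplicated.
    mid = start + (end - start) // 2
    pos = find_changes(arr, start, mid + 1, out, pos)
    pos = find_changes(arr, mid, end, out, pos)
    return pos
```

A quick check on a tiny array:

```python
arr = np.array([True, True, False, False, True])
out = np.zeros(64, dtype=np.int64)
n = find_changes(arr, 0, len(arr), out, 0)
print(out[:n].tolist())  # [1, 1, 2, 2, 3, -2, 4, -1]
```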
A purely iterative version of this function (iterative in concept, not just in implementation), which simply loops over every element of `arr`, compares it with the previous and next one, and notes the indices where changes happen, runs MUCH faster than this recursive binary search when using `@jit` (and the gap widens as the array grows). I'm not sure I understand why this happens, as intuitively I would say the binary-search method should perform in O(log n) instead of the O(n) of the iterative method.
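For comparison, the iterative version I have in mind is roughly the following (a sketch under the same output convention; `iterative_changes` is an illustrative name, and a Numba version would likely write into a preallocated array rather than growing a Python list):

```python
import numpy as np

def iterative_changes(arr):
    """Single O(n) pass over adjacent pairs, recording (index, code)
    quadruples for every state change, same codes as binary_search."""
    out = []
    for i in range(len(arr) - 1):
        if arr[i] and not arr[i + 1]:        # true -> false
            out.extend([i, 1, i + 1, 2])
        elif not arr[i] and arr[i + 1]:      # false -> true
            out.extend([i, -2, i + 1, -1])
    return out

print(iterative_changes(np.array([True, True, False, False, True])))
# [1, 1, 2, 2, 3, -2, 4, -1]
```

A NumPy-only alternative for the raw transition positions would be `np.flatnonzero(arr[:-1] != arr[1:])`, which finds every "before the change" index in one vectorized pass.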
--
Note: Numba doesn't complain about `binary_search` when using `@jit(nopython=True)`.