TL;DR
In general, np.random.choice(replace=False) performs well and usually better than random.sample() (which cannot be used inside Numba-JITted functions in NoPython mode anyway), but for smaller values of k it is best to use random_sample_shuffle_idx(), except when k is very small relative to the input size, in which case random_sample_set() should be used.
Parallel Numba compilation may speed up execution for sufficiently large inputs, but will result in race conditions potentially invalidating the sampling, and it is therefore best avoided.
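To illustrate why, here is a hedged sketch (not part of the code discussed below) of what a naively parallelized index-shuffling variant might look like; the name random_sample_shuffle_idx_par and the use of np.random.randint() are illustrative choices. The point is that concurrent iterations swap overlapping positions of the index array, so the returned "sample" can contain duplicates:
import numpy as np
import numba as nb

# Hypothetical sketch: naive parallelization of the index-shuffling approach.
# The loop iterations are NOT independent (iteration i may swap index[i] with
# any index[j], j >= i), so running them concurrently with nb.prange races on
# the index array and the result may contain repeated picks.
@nb.njit(parallel=True)
def random_sample_shuffle_idx_par(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    index = np.arange(n)
    for i in nb.prange(k):  # racy concurrent swaps
        j = np.random.randint(i, n)  # uniform over [i, n)
        index[i], index[j] = index[j], index[i]
    return arr[index[:k]]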
Discussion
A simple Numba-compatible re-implementation of random.sample() can easily be written:
import random
import numba as nb
import numpy as np
@nb.njit
def random_sample_set(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    # Numba cannot type an empty set literal, hence the create-then-clear idiom.
    seen = {0}
    seen.clear()
    # Indices are integers regardless of the input's dtype.
    index = np.empty(k, dtype=np.int64)
    for i in range(k):
        # Rejection sampling: draw uniformly over [0, n) until a fresh index is found.
        j = random.randint(0, n - 1)
        while j in seen:
            j = random.randint(0, n - 1)
        seen.add(j)
        index[i] = j
    return arr[index]
This uses a temporary set named seen which stores all previously drawn indices so that none of them is re-used. This should always be faster than random.sample().
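A quick usage sketch (with an arbitrary input array and k, just to show what the function returns):
import numpy as np

arr = np.arange(100)
sample = random_sample_set(arr, 5)
print(sample)            # 5 elements of arr in random order
print(len(set(sample)))  # 5 -- no element is picked twice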
A potentially much faster version of this can be written by just shuffling (a portion of a copy of) the input:
import random
import numba as nb
@nb.njit
def random_sample_shuffle(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    result = arr.copy()
    # Partial Fisher-Yates shuffle: only the first k positions are needed.
    for i in range(k):
        j = random.randint(i, n - 1)
        result[i], result[j] = result[j], result[i]
    return result[:k].copy()
This is faster than random_sample_set() as long as k is sufficiently large. The larger the input array, the larger k needs to be to outperform random_sample_set(): the shuffle always pays for copying the entire input, whereas random_sample_set() only degrades when k approaches n and the collision rate of j in seen becomes high, causing the while-loop to run multiple times on average. Also, random_sample_shuffle() has a memory footprint independent of k (it copies the whole input), which is higher than that of random_sample_set(), whose memory footprint is proportional to k.
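To put a number on the collision argument: with i indices already taken, a uniform draw over [0, n) lands on a fresh index with probability (n - i) / n, so the i-th pick needs about n / (n - i) draws on average. A small back-of-the-envelope sketch (not part of the benchmark code) makes the crossover visible:
def expected_draws(n, k):
    # Expected total number of random.randint() calls made by
    # random_sample_set(): the i-th pick is a geometric trial with
    # success probability (n - i) / n, i.e. n / (n - i) draws on average.
    return sum(n / (n - i) for i in range(k))

n = 1_048_576
print(expected_draws(n, 16))      # ~16: essentially no collisions
print(expected_draws(n, n // 8))  # ~1.07 * k: still cheap
print(expected_draws(n, n - 1))   # ~n * ln(n), i.e. over ten million draws
This is why the shuffle-based versions, which do a fixed amount of work per element, win for large k, while random_sample_set() wins when k is tiny compared to n.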
A slight variation of random_sample_shuffle() is random_sample_shuffle_idx(), which shuffles an array of indices instead of a copy of the input. Its memory footprint is independent of the input's data type (it always allocates an integer index array of size n), which makes it more efficient for larger data types; it is also significantly faster for small values of k, and typically on par with (or slightly slower than) random_sample_shuffle() in the general case:
import random
import numpy as np
import numba as nb
@nb.njit
def random_sample_shuffle_idx(arr, k=-1):
    n = arr.size
    if k < 0:
        k = arr.size
    # Shuffle an integer index array instead of a copy of the (possibly large) items.
    index = np.arange(n)
    for i in range(k):
        j = random.randint(i, n - 1)
        index[i], index[j] = index[j], index[i]
    return arr[index[:k]]
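To make the data-type-independence claim concrete, here is a small illustrative comparison of the temporary storage involved; the sizes assume 64-bit default integers for np.arange():
import numpy as np

n = 1_000_000
bulky = np.zeros(n, dtype=np.complex128)  # 16 MB of items

# random_sample_shuffle() copies the whole input, so its temporary storage
# scales with the item size (16 MB here), while random_sample_shuffle_idx()
# always allocates one integer index array (8 MB with 64-bit indices),
# regardless of the input's dtype.
print(bulky.copy().nbytes)  # 16_000_000 bytes
print(np.arange(n).nbytes)  # 8_000_000 bytes (64-bit default integers)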
When comparing the above with np.random.choice(replace=False):
import numpy as np
import numba as nb
@nb.njit
def random_sample_choice(arr, k=-1):
    if k < 0:
        k = arr.size
    return np.random.choice(arr, k, replace=False)
one would observe that it sits between random_sample_set() and random_sample_shuffle() in terms of speed, as long as the input is sufficiently small and k is not too small.
Of course, np.random.choice() and its newer counterpart np.random.Generator.choice() offer a lot more functionality than these simple implementations.
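For reference, in plain NumPy (outside of Numba) the newer Generator API performs the same without-replacement sampling; the seed below is arbitrary and only there for reproducibility:
import numpy as np

rng = np.random.default_rng(42)                   # arbitrary seed
arr = np.arange(1_000_000)
sample = rng.choice(arr, size=16, replace=False)  # 16 distinct elements of arr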
Benchmarks
Some quick benchmarks can be generated with the following:
funcs = (
    random_sample_set, random_sample_shuffle,
    random_sample_shuffle_idx, random_sample_choice)


def is_good(x):
    # A valid sample without replacement contains no repeated elements.
    return len(x) == len(set(x))


for q in range(4, 24, 4):
    n = 2 ** q
    arr = np.arange(n)
    seq = arr.tolist()
    for k in range(n // 8, n + 1, n // 8):
        print(f"n = {n}, k = {k}")
        func = random.sample
        print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
        %timeit -n 1 -r 1 func(seq, k)
        for func in funcs:
            print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
            %timeit -n 4 -r 4 func(arr, k)
The most interesting results are:
...
n = 65536, k = 65536
sample True 41 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 22 ms ± 1 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 1 ms ± 113 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 948 µs ± 94 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 918 µs ± 67.7 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
n = 1048576, k = 131072
sample True 136 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 19 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 5.85 ms ± 303 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 6.95 ms ± 445 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 26.1 ms ± 1.93 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
n = 1048576, k = 917504
sample True 916 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 313 ms ± 47.6 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 29.4 ms ± 1.87 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 32.8 ms ± 1.55 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 28.2 ms ± 1.06 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
...
And the small-k regime (which is the use case of the question):
for q in range(4, 28, 4):
    n = 2 ** q
    arr = np.arange(n)
    seq = arr.tolist()
    k = 16
    print(f"n = {n}, k = {k}")
    func = random.sample
    print(f"{func.__name__:>24s} {is_good(func(seq, k))!s:>5s}", end=" ")
    %timeit -n 1 -r 1 func(seq, k)
    for func in funcs:
        print(f"{func.__name__:>24s} {is_good(func(arr, k))!s:>5s}", end=" ")
        %timeit -n 4 -r 4 func(arr, k)
n = 16, k = 16
sample True 39.1 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 5.11 µs ± 2.95 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 2.62 µs ± 1.67 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 2.51 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 2.39 µs ± 1.47 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 256, k = 16
sample True 43.7 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.67 µs ± 2.36 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 2.59 µs ± 1.72 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True The slowest run took 4.44 times longer than the fastest. This could mean that an intermediate result is being cached.
2.8 µs ± 2.16 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 5.47 µs ± 1.8 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 4096, k = 16
sample True 33.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.53 µs ± 1.73 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 4.2 µs ± 1.81 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 3.23 µs ± 1.46 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 51.7 µs ± 4.82 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 65536, k = 16
sample True 58.9 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 4.15 µs ± 2.75 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 35.3 µs ± 7.99 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 15.1 µs ± 5.03 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True The slowest run took 5.93 times longer than the fastest. This could mean that an intermediate result is being cached.
2.3 ms ± 1.61 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 1048576, k = 16
sample True 48.2 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 3.89 µs ± 2.01 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 1.87 ms ± 163 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 810 µs ± 195 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 30.6 ms ± 2.41 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
n = 16777216, k = 16
sample True 70.4 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
random_sample_set True 5.15 µs ± 3.65 µs per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle True 103 ms ± 1.84 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_shuffle_idx True 75.2 ms ± 3.31 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)
random_sample_choice True 863 ms ± 77.3 ms per loop (mean ± std. dev. of 4 runs, 4 loops each)