1

I have a boolean matrix in numpy with shape (m, n).

I initialize the matrix elements to be False.

I want to randomly set exactly x elements in each row (x < n) with the value True.

Now I go over the matrix with a loop, using np.random.choice with no replacement:

import numpy as np

mat = np.full((M, N), fill_value=False)
for i in range(mat.shape[0]):
    mat[i, np.random.choice(mat.shape[1], x, replace=False)] = True

Is there a more efficient way to do this with numpy?

user107511
  • 1
    You can use `np.add.at` but [it is not very efficient](https://stackoverflow.com/a/72048786/12939557) so you can follow the same approach: just use Numba the same way. What is `x` in practice? If it is small, then the implementation of `np.random.choice` is sub-optimal. – Jérôme Richard May 04 '22 at 08:22
  • `N` is quite large (~1000) and `x` small (~10). `numba` also implements `choice` (only without probabilities), will that approach be better than the `numba` implementation? – user107511 May 04 '22 at 08:40
  • 1
    Try with the new random API: https://stackoverflow.com/questions/40914862/why-is-random-sample-faster-than-numpys-random-choice – ayhan May 04 '22 at 09:16
  • @ayhan I work with `default_rng` already – user107511 May 04 '22 at 09:26

4 Answers

5

np.random.choice is suboptimal when the number of values to pick is small compared to the size of the array. This is because the current implementation uses a partitioning method. A faster approach is to pick x random positions, check whether there are duplicates, and repeat until all the positions are different. A retry is rarely needed when x is small relative to N: for N = 1000 and x = 10, as in the question, a single attempt yields distinct positions with a probability of about 0.96. Numba can speed up this process. Here is the resulting code:

import numba as nb
import numpy as np

@nb.njit('(int_, int_, int_[::1])')
def pick(x, N, out):
    assert out.size == x
    if x / N <= 0.05:
        while True:
            for j in range(x):
                out[j] = np.random.randint(0, N)
            out.sort()
            ok = True
            for i in range(x-1):
                if out[i] == out[i+1]:
                    ok = False
            if ok: return
    out[:] = np.random.choice(N, x, replace=False)

@nb.njit('bool_[:,::1](int_, int_, int_)')
def compute(M, N, x):
    mat = np.zeros((M, N), dtype=np.bool_)
    cols = np.empty(x, np.int_)
    for i in range(M):
        pick(x, N, cols)
        for j in cols:
            mat[i, j] = True
    return mat

N, M = 1000, 1000
x = 10
mat = compute(M, N, x)
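As a rough check on the retry strategy used in `pick`, the chance that one attempt produces x distinct positions can be computed directly. This is only a sketch: the helper name `prob_all_distinct` is made up for illustration, and it assumes the draws are independent and uniform over `range(N)`.

```python
import math

def prob_all_distinct(n, x):
    """Probability that x independent uniform draws from range(n) are all different."""
    # The k-th draw (0-based) must avoid the k values already taken.
    return math.prod(1 - k / n for k in range(x))

p_question = prob_all_distinct(1000, 10)  # sizes mentioned in the question's comments
p_boundary = prob_all_distinct(1000, 50)  # x/N exactly at the 0.05 threshold
print(f"{p_question:.3f} {p_boundary:.3f}")
```

For the sizes in the question the first value is close to 0.96. The second is considerably lower, which suggests the success rate depends on x itself and not only on the ratio x/N, so the threshold may need tuning for larger x.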

An even faster and simpler approach is to set the values directly in the array, as proposed by Kelly Bundy. This has the benefit of avoiding a slow sort operation. Here is the resulting code:

import numba as nb
import numpy as np

@nb.njit('bool_[:,::1](int_, int_, int_)')
def compute(M, N, x):
    mat = np.zeros((M, N), dtype=np.bool_)
    for i in range(M):
        if x/N <= 0.20:
            k = 0
            while k < x:
                j = np.random.randint(0, N)
                if not mat[i, j]:
                    mat[i, j] = True
                    k += 1
        else:
            for j in np.random.choice(N, x, replace=False):
                mat[i, j] = True
    return mat

N, M = 1000, 1000
x = 10
mat = compute(M, N, x)

This is 276 times faster than the initial approach on my machine and also much faster than the other answers.

Results

Initial:         27.61 ms
Salvatore D.B.:  20.54 ms
D.Manasreh:      14.90 ms
Numba V1:         0.66 ms
Numba V2:         0.10 ms  <---
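
For readers who want to follow the logic without Numba, the rejection loop of the second version can be sketched with an explicit `np.random.Generator`. This is an illustration only, not the benchmarked code: the function name `fill_rows_by_rejection` is hypothetical, the seed is arbitrary, and the pure-Python loop is expected to be much slower than the compiled version. Passing the generator as an argument is one possible way to give each thread its own random state.

```python
import numpy as np

def fill_rows_by_rejection(m, n, x, rng):
    """Set exactly x random entries per row to True, redrawing on collisions."""
    mat = np.zeros((m, n), dtype=np.bool_)
    for i in range(m):
        row = mat[i]                 # 1-D view into the matrix
        k = 0
        while k < x:
            j = rng.integers(0, n)   # uniform column index in [0, n)
            if not row[j]:           # only count positions that were still False
                row[j] = True
                k += 1
    return mat

rng = np.random.default_rng(12345)   # arbitrary seed, for reproducibility only
mat = fill_rows_by_rejection(1000, 1000, 10, rng)
print(mat.sum(axis=1).min(), mat.sum(axis=1).max())
```

Both printed values should equal x, since a row is only finished once x distinct positions have been set.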
Jérôme Richard
  • How about simply setting x elements to True, using the row to check whether an element is already True? [Pure Python version](https://tio.run/##PU3BCsIwDL2/r8hxEwRFBBG8euzJ29hhYOcKW1OySuvX12hhgfDyXt5LwidO7E@XIKWMwgvJ4J8Kbgks8c@0XxYwdKPjAVnhDAgnHbr7MK@2px0ZIE1utpSvIC2n2y3cmLaKI3nWo5w611ffryrXwEPedlMz7fUhEMT52KinLeUL), as I'm not familiar with Numba. – Kelly Bundy May 04 '22 at 10:17
  • @KellyBundy. Indeed. I was unsure whether the distribution was perfectly uniform, but it is fine in practice. This is even faster. Thank you. – Jérôme Richard May 04 '22 at 11:14
  • It's the same approach as what Python's [`random.sample`](https://github.com/python/cpython/blob/178a238f25ab8aff7689d7a09d66dc1583ecd6cb/Lib/random.py#L496-L503) might do, so it better be correct :-). Btw, would it be even faster to store `np.random.randint` and the row in variables, to save the attribute lookups and to only do 1-dimensional indexing, or does Numba optimize those things anyway? – Kelly Bundy May 04 '22 at 11:19
  • This is great, and surely `numba` helps here. The only disadvantage is that `numba`'s random seed is global, so if I have several threads running, each with its own random state, that wouldn't work. – user107511 May 04 '22 at 13:21
  • 1
    @KellyBundy Good to know `random.sample` uses an efficient algorithm, though it is a bit sad it is not written in C. For the attribute lookups, Numba should optimize those things (the 1D indexing is often optimized, but sometimes the JIT generates inefficient code; this is not a limiting factor here). Most of the time is now spent in the assignment (which causes slow cache misses and often page faults, but this cannot be avoided here AFAIK) and a bit in the RNG. – Jérôme Richard May 04 '22 at 18:32
  • @user107511 Numba uses a thread-local seed, so it is not shared between Numba threads. Note that the Numba seed is not synchronized with the one of Numpy (they use two different variables, and in fact Numba has its own implementation of Numpy functions). Note that if you spawn non-Numba threads and use Numba functions, then you can set the seed in the Numba function and pass it as a parameter ;). – Jérôme Richard May 04 '22 at 18:34
2

Try something like this:

import numpy as np

m = 100
n = 100
x = 60

# generate a matrix of shape (m, n) with x True values at the start of each row
mat0 = np.tile(np.repeat([True, False], [x, (n-x)]), [m, 1])

# permute on rows
mat = np.random.default_rng().permuted(mat0, axis=1)
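
The snippet above relies on `Generator.permuted` shuffling each row independently when `axis=1` is given (this method needs NumPy 1.20 or newer). A small check, with sizes and seed chosen arbitrarily for illustration, might look like this:

```python
import numpy as np

m, n, x = 5, 8, 3
rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility

template = np.repeat([True, False], [x, n - x])  # x True values, then n - x False
mat0 = np.tile(template, (m, 1))                 # every row starts out identical
mat = rng.permuted(mat0, axis=1)                 # each row is shuffled on its own

row_counts = mat.sum(axis=1)
print(row_counts)
```

Every entry of `row_counts` should equal x, because shuffling within a row moves the True values around without changing how many there are.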

D.Manasreh
2

This solution is heavily inspired by Jérôme Richard's implementation of Kelly Bundy's approach, but with guaranteed x iterations per row. I don't know why it is slower than their x/N <= .2 branch.

import numba as nb # tested with numba 0.55.1
import numpy as np

@nb.njit('bool_[:,::1](int_, int_, int_)')
def compute1(M, N, x):
    mat = np.zeros((M, N), dtype=np.bool_)
    for i in range(M):
        for j in range(N-x, N):
            y = np.random.randint(j+1)
            if mat[i, y]: y = j
            mat[i, y] = True
    return mat
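
The inner loop above appears to follow the sampling scheme commonly attributed to Robert Floyd, which draws exactly x random numbers per sample. A pure-Python sketch of that scheme is shown here to make the invariant easier to see; the name `floyd_sample` and the use of a set instead of a boolean row are choices made for this illustration only.

```python
import random

def floyd_sample(n, x, rng=random):
    """Return a set of x distinct integers from range(n) using exactly x draws."""
    chosen = set()
    for j in range(n - x, n):
        y = rng.randrange(j + 1)  # uniform in 0..j inclusive
        if y in chosen:
            y = j                 # j has not been reachable before, so it is still free
        chosen.add(y)
    return chosen

sample = floyd_sample(1000, 10, random.Random(42))  # arbitrary seed
print(len(sample))
```

The fallback to `j` never collides because earlier iterations could only choose values below `j`, which is why no retry loop is required.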
Michael Szczesny
  • How much slower is it? Looks nice, haven't seen that before, will have to think it through. – Kelly Bundy May 04 '22 at 17:35
  • @KellyBundy - about 8% slower for different test data on a 2-core colab instance. Using a while loop reduces it to ~6% (but looks ugly). Not sure if it needs to do more or if it can't be optimized by the compiler. – Michael Szczesny May 04 '22 at 17:49
  • How about [these variations](https://tio.run/##K6gsycjPM7YoKPr/XwEK0vKLFLIUMvMUihLz0lM1/HQrtA11FPy0DTWtuBSQQKWCrUJegR5QVUp@LpjKzCvRyNJEUZSZppCbWBKdqaNQGWsFY2bFArWGFJWmoihNzSlOtUKohinhoom7oNZoZKG4EOwGhUpNmOX//wMA)? – Kelly Bundy May 04 '22 at 18:13
  • Neither is correct: the samples are not uniformly distributed (notice `randint(j+1)`, but setting `y = j`). This is the hard part to get right; you can see me struggling with it in the revision history. They are also not faster in my benchmarks. – Michael Szczesny May 04 '22 at 18:32
  • I cannot see any significant difference in terms of performance between the two codes on my Linux machine (which physically fills the array with zeros in practice). On Windows, I can see the gap you mention (and my Windows does not physically fill the array with zeros). So I thought the difference on Windows was due to the page faults. However, there is a 7-8% difference on my Windows between the two codes when the array is preallocated & filled. In the end, it looks like this is due to the optimization (something affects the JIT optimization of the code, or maybe it is just a heuristic issue). – Jérôme Richard May 04 '22 at 21:30
  • 1
    The main loop assembly code for your function seems more expensive/bigger than mine (at least on Windows). I am not exactly sure why... Here is the generated code: [compute](https://pastebin.com/acEynp5t) & [compute1](https://pastebin.com/QygAQ1s8). Note for example that the rax register is reloaded in the hot loop for no apparent reason. That being said, I think the `j+1` in the RNG makes the code a bit slower because the range is not a constant and thus some expression might not be precomputed once. – Jérôme Richard May 04 '22 at 21:41
  • 1
    Another reason is that `y` may cause additional instructions to be executed if the JIT cannot prove at compile time that it is always in the range 0..N (because Numba needs to check for special cases like negative values). This unfortunately often happens with Numba. In fact, the code was 20 times slower when Numba did not know the size of the `mat` array in the function (it can be told by using an assert or by creating the array in the function with a `(M, N)` shape). Btw, note that the number of iterations is the same for the two implementations for the provided inputs (thanks to x being small). – Jérôme Richard May 04 '22 at 21:49
  • @MichaelSzczesny Ah right, need to adjust `j` to `j-1`. (I had thought of it but then forgot, and only being on the phone I didn't test them). Anyway, the idea is to do the `-1` only in the rare exception, instead of the `+1` every time. – Kelly Bundy May 04 '22 at 22:51
0
import numpy as np

# input
M, N, x = 100,9,3
mat = np.full((M, N), fill_value=False)

#solution
mat[np.repeat(np.arange(M), x), np.ravel([np.random.permutation(N)[:x] for i in range(M)])]=True

Output:

array([[False,  True,  True, False, False, False, False, False,  True],
   [ True,  True, False,  True, False, False, False, False, False],
   [False,  True, False, False, False, False,  True, False,  True],
   [ True,  True, False,  True, False, False, False, False, False],
   [False, False, False,  True, False, False,  True, False,  True],
   [False,  True, False,  True, False,  True, False, False, False],
   [False, False, False,  True, False,  True,  True, False, False],
   [ True, False, False, False, False,  True,  True, False, False],
   [ True, False, False,  True,  True, False, False, False, False],
   [ True, False, False, False, False, False,  True,  True, False],
   [ True,  True, False, False, False,  True, False, False, False],
   [ True,  True, False, False, False, False,  True, False, False],
   [ True, False,  True, False, False, False, False,  True, False],
   [False, False, False, False, False,  True,  True,  True, False],
   [False, False,  True,  True, False,  True, False, False, False],
   [False, False, False,  True, False,  True,  True, False, False],
   [False,  True, False,  True, False, False, False, False,  True],
   [False, False, False, False,  True,  True, False,  True, False],
   [False, False, False, False, False,  True,  True, False,  True],
   [ True, False,  True, False, False, False, False, False,  True],
   [False,  True, False, False, False, False, False,  True,  True],
   [ True, False, False, False, False,  True, False,  True, False],
   [False, False,  True, False, False, False, False,  True,  True],
   [ True, False, False, False, False,  True,  True, False, False],
   [ True,  True, False, False, False, False, False,  True, False],
   [False, False, False,  True, False, False,  True, False,  True],
   [False, False,  True,  True, False, False, False,  True, False],
   [False, False, False, False, False,  True, False,  True,  True],
   [False, False, False, False,  True,  True, False, False,  True],
   [False, False,  True, False, False,  True, False,  True, False],
   [False,  True, False, False,  True,  True, False, False, False],
   [False, False,  True,  True, False, False,  True, False, False],
   [False, False, False,  True, False,  True, False,  True, False],
   [ True, False, False, False, False, False,  True,  True, False],
   [False,  True, False, False,  True, False, False, False,  True],
   [False, False, False,  True, False, False, False,  True,  True],
   [False, False,  True, False,  True,  True, False, False, False],
   [False,  True,  True, False, False, False, False, False,  True],
   [False, False,  True,  True, False, False,  True, False, False],
   [False, False, False,  True, False,  True,  True, False, False],
   [False,  True,  True, False, False, False, False, False,  True],
   [False, False,  True, False, False,  True, False,  True, False],
   [False,  True, False, False,  True, False,  True, False, False],
   [False,  True, False, False, False, False, False,  True,  True],
   [False, False, False,  True, False, False, False,  True,  True],
   [ True, False, False,  True,  True, False, False, False, False],
   [False, False, False,  True, False,  True, False, False,  True],
   [False, False,  True, False,  True, False,  True, False, False],
   [ True, False, False, False, False, False,  True, False,  True],
   [ True,  True, False, False, False,  True, False, False, False],
   [False, False,  True,  True, False, False,  True, False, False],
   [False, False,  True,  True,  True, False, False, False, False],
   [False, False,  True, False,  True, False, False, False,  True],
   [False, False,  True, False, False,  True,  True, False, False],
   [ True, False, False,  True, False,  True, False, False, False],
   [ True, False, False, False,  True, False, False,  True, False],
   [False,  True,  True, False, False, False,  True, False, False],
   [False, False, False, False, False, False,  True,  True,  True],
   [ True, False,  True, False, False, False,  True, False, False],
   [False,  True,  True, False, False, False, False,  True, False],
   [False,  True, False,  True, False, False, False,  True, False],
   [False,  True, False,  True,  True, False, False, False, False],
   [ True, False, False, False, False, False,  True,  True, False],
   [ True, False,  True, False, False, False,  True, False, False],
   [False, False,  True, False, False, False,  True,  True, False],
   [False, False, False, False, False,  True, False,  True,  True],
   [False, False, False, False, False,  True,  True,  True, False],
   [False, False, False, False, False,  True,  True,  True, False],
   [ True,  True, False, False,  True, False, False, False, False],
   [ True, False,  True, False, False, False,  True, False, False],
   [False, False, False, False, False, False,  True,  True,  True],
   [False,  True, False, False,  True,  True, False, False, False],
   [False,  True, False,  True, False, False, False,  True, False],
   [False, False, False, False,  True,  True, False, False,  True],
   [ True,  True, False, False, False, False,  True, False, False],
   [False,  True, False, False,  True, False, False,  True, False],
   [False, False,  True,  True, False,  True, False, False, False],
   [False,  True, False,  True, False,  True, False, False, False],
   [False, False, False,  True, False,  True, False, False,  True],
   [ True, False, False, False, False, False,  True, False,  True],
   [False, False, False,  True, False, False,  True,  True, False],
   [False, False, False, False,  True,  True,  True, False, False],
   [False, False, False,  True, False, False, False,  True,  True],
   [False,  True, False, False, False,  True,  True, False, False],
   [False, False, False,  True, False,  True, False,  True, False],
   [False,  True, False, False,  True, False, False,  True, False],
   [False,  True, False, False,  True, False, False, False,  True],
   [False, False, False, False, False,  True, False,  True,  True],
   [ True, False, False, False,  True, False, False,  True, False],
   [ True, False,  True, False, False,  True, False, False, False],
   [False, False,  True, False, False, False,  True, False,  True],
   [False, False,  True,  True, False,  True, False, False, False],
   [False, False, False, False,  True,  True,  True, False, False],
   [False, False, False, False,  True, False, False,  True,  True],
   [False, False, False, False,  True, False,  True, False,  True],
   [False,  True,  True,  True, False, False, False, False, False],
   [ True, False, False, False, False,  True, False, False,  True],
   [ True,  True, False,  True, False, False, False, False, False],
   [ True, False, False, False,  True, False, False, False,  True],
   [ True, False, False,  True,  True, False, False, False, False]])

On my machine this takes 0.002034902572631836 seconds (about 2.0 ms), versus 0.0050237178802490234 seconds (about 5.0 ms) for the loop in the question.

#check the results
(mat.sum(1)==3).all() #True
mat.sum(0) #array([34, 30, 33, 46, 23, 35, 36, 31, 32])