
I have an array of the shape (2,10) such as:

arr = jnp.ones(shape=(2,10)) * 2

or

[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]

and another array, for example [2,4].

I want the second array to specify, for each row, the index from which the elements of arr should be masked (set to -1). Here the result would be:

[[ 2.  2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [ 2.  2.  2.  2. -1. -1. -1. -1. -1. -1.]]

I need a solution that uses jax.numpy and is vectorized (i.e. no Python loops) and, if possible, fast.

Valentin Macé
  • With `numpy` it's straightforward: `for i in range(2): arr[i, idx[i]:] = -1`. Were you looking for some personal magic? – hpaulj Jun 08 '22 at 23:16
  • Thanks for your answer; however, I was looking for something vectorized and fast using jax.numpy. – Valentin Macé Jun 08 '22 at 23:42
  • I, and possibly others, like to see a more explicit question, one that shows what you are doing - even if it iterates as I suggest. Don't leave the magical "vectorized" goal unstated! – hpaulj Jun 08 '22 at 23:45
  • I do agree. I edited my question for potential future readers. – Valentin Macé Jun 08 '22 at 23:50
  • In some quick testing, the relative speed of the loop versus the masked `where` depends on the array shape (n,m). Where `m` is relatively large, using `n` slices is faster. – hpaulj Jun 09 '22 at 05:43
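For comparison, here is a minimal sketch of hpaulj's loop translated to JAX, assuming the start indices are plain Python integers (the name `starts` is just for illustration). JAX arrays are immutable, so the in-place NumPy assignment becomes `.at[...].set(...)`, which returns an updated copy. Because each slice bound must be concrete, this form would typically fail if the indices were traced values inside `jax.jit`.

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
starts = [2, 4]  # assumed: plain Python ints, so every slice start is static

# JAX arrays are immutable: .at[...].set(...) returns a new array
# with the slice arr[i, start:] replaced by -1.
for i, start in enumerate(starts):
    arr = arr.at[i, start:].set(-1)

print(arr)
```

This issues one update per row, which matches the `n`-slices approach discussed in the comments.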


You can do this with a vmapped three-argument jnp.where call. For example:

import jax.numpy as jnp
import jax

arr = jnp.ones(shape=(2,10)) * 2
idx = jnp.array([2, 4])

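@jax.vmap  # map f over the leading axis of both arr and idx
def f(row, ind):
  # keep entries whose column index is below ind; replace the rest with -1
  return jnp.where(jnp.arange(len(row)) < ind, row, -1)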

f(arr, idx)
# DeviceArray([[ 2.,  2., -1., -1., -1., -1., -1., -1., -1., -1.],
#              [ 2.,  2.,  2.,  2., -1., -1., -1., -1., -1., -1.]], dtype=float32)
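The same mask can also be sketched without `vmap` by relying on broadcasting: compare a row of column indices of shape `(m,)` against `idx[:, None]` of shape `(n, 1)` to get an `(n, m)` boolean mask. The function name `mask_from` and its `fill_value` parameter are hypothetical, introduced only for this illustration; wrapping it in `jax.jit` works because the mask depends only on array values, not on data-dependent slice bounds.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mask_from(arr, idx, fill_value=-1):
    # Column indices 0..m-1, shape (m,); arr.shape[1] is static under jit
    cols = jnp.arange(arr.shape[1])
    # Broadcast (m,) against (n, 1): keep[i, j] is True where j < idx[i]
    keep = cols < idx[:, None]
    # Keep the original values where the mask is True, else use fill_value
    return jnp.where(keep, arr, fill_value)

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])
result = mask_from(arr, idx)
print(result)
```

This should be equivalent to the vmapped version; `vmap` effectively inserts the same broadcast dimension for you.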
jakevdp