I have an array of shape (2, 10), for example:
import jax.numpy as jnp
arr = jnp.ones(shape=(2, 10)) * 2
which prints as:
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
and another array, for example [2,4].
I want the second array to give, for each row of arr, the index from which
the elements should be masked (replaced with -1). Here the result would be:
[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
[2. 2. 2. 2. -1. -1. -1. -1. -1. -1.]]
I need to use jax.numpy, and the answer should be vectorized and fast if possible, i.e. it should not use loops.
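To make the intended behaviour concrete, here is a minimal sketch of one possible loop-free approach. It assumes the second array holds exactly one start index per row and that -1 is the mask value; the names starts, cols and mask are made up for illustration. It compares a broadcast range of column indices against the per-row start indices and selects with jnp.where:

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
starts = jnp.array([2, 4])  # hypothetical name for the second array: one start index per row

# Column indices 0..9, compared against each row's start index via broadcasting.
cols = jnp.arange(arr.shape[1])          # shape (10,)
mask = cols[None, :] >= starts[:, None]  # shape (2, 10), True where the element should be masked

# Keep the original value where mask is False, write -1 where it is True.
result = jnp.where(mask, -1.0, arr)
print(result)
```

Since this only uses broadcasting and jnp.where, it should also work under jax.jit, though I have not benchmarked it.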