I have a function `compute(x)` where `x` is a `jnp.ndarray`. Now, I want to use `vmap` to transform it into a function that takes a batch of arrays `x[i]`, and then `jit` to speed it up. `compute(x)` is something like:
```python
import jax.numpy as jnp

def compute(x: jnp.ndarray) -> jnp.ndarray:
    # ... some code
    y = very_expensive_function(x)
    return y
```
However, each array `x[i]` has a different length. I can easily work around this problem by padding the arrays with trailing zeros so that they all have the same length `N`, and `vmap(compute)` can then be applied to batches with shape `(batch_size, N)`.
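A minimal sketch of the padding workaround described above. `very_expensive_function` here is a hypothetical elementwise stand-in, and `pad_batch` is a hypothetical helper; neither comes from the original code. The point it shows is that every padded slot is still fed through the expensive function.

```python
import jax
import jax.numpy as jnp
import numpy as np

def very_expensive_function(x):
    # Hypothetical stand-in for the real expensive computation (elementwise here).
    return jnp.sin(x) ** 2 + jnp.cos(x)

def compute(x):
    return very_expensive_function(x)

def pad_batch(arrays, n):
    # Hypothetical helper: pad each 1-D array with trailing zeros to length n, then stack.
    return np.stack([np.pad(a, (0, n - len(a))) for a in arrays])

arrays = [np.arange(3.0), np.arange(5.0), np.arange(2.0)]
N = max(len(a) for a in arrays)
batch = pad_batch(arrays, N)          # shape (3, 5)

batched_compute = jax.jit(jax.vmap(compute))
out = batched_compute(batch)          # shape (3, 5); padded slots are computed too
```

In the last row only the first two entries are real data, yet the three trailing zeros are also evaluated (each yields `sin(0)**2 + cos(0) = 1`).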
Doing so, however, causes `very_expensive_function()` to also be called on the trailing zeros of each array `x[i]`. Is there a way to modify `compute()` so that `very_expensive_function()` is called only on a slice of `x`, without interfering with `vmap` and `jit`?
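For context, a sketch of one common pattern, assuming each true length is passed alongside the padded array. Under `jit`, array shapes must be static, so a slice like `x[:length]` with a traced `length` is not allowed (likewise, `jax.lax.dynamic_slice` takes a static slice size). Masking keeps the padded entries out of the result, but note that it does **not** avoid evaluating `very_expensive_function` on them. `compute_masked` and the stand-in function are hypothetical names.

```python
import jax
import jax.numpy as jnp

def very_expensive_function(x):
    # Hypothetical stand-in for the real computation (elementwise).
    return jnp.sin(x) ** 2 + jnp.cos(x)

def compute_masked(x, length):
    # x has the padded, static shape (N,); length is the true (traced) length.
    # x[:length] would have a data-dependent shape, so build a boolean mask instead.
    mask = jnp.arange(x.shape[0]) < length
    y = very_expensive_function(x)  # still evaluated on all N entries
    # Replace results at padded positions with zeros.
    return jnp.where(mask, y, 0.0)

batched = jax.jit(jax.vmap(compute_masked))

x = jnp.array([[0.0, 1.0, 2.0, 0.0, 0.0],
               [0.5, 1.5, 0.0, 0.0, 0.0]])
lengths = jnp.array([3, 2])
out = batched(x, lengths)
```

If the real function can produce NaN or Inf on the zero padding (e.g. `log`), a common refinement is to also pass a safe value into it at masked positions, since NaNs can still leak into gradients through `jnp.where`.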