It seems there are two issues in your implementation. First, slicing with a traced `index` (`x[:index]`) produces an array whose shape depends on a runtime value, and such dynamically shaped arrays are not allowed inside jitted code. Second, unlike NumPy arrays, JAX arrays are immutable, so their contents cannot be modified in place.
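Both failure modes can be reproduced in a minimal sketch, assuming a recent JAX release. `naive_prefix` is a hypothetical function standing in for the original code, and the exact exception types may differ between JAX versions, so the sketch catches them generically.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# Issue 1: JAX arrays are immutable, so in-place item assignment raises an error.
try:
    x[0] = 10
    immutable_error = None
except Exception as e:  # exception type may vary by JAX version
    immutable_error = type(e).__name__

# Issue 2 (hypothetical naive_prefix): slicing with a traced index inside jit
# fails, because x[:index] would need a shape that depends on a runtime value.
@jax.jit
def naive_prefix(x, index):
    return x[:index] + 1

try:
    naive_prefix(x, 3)
    dynamic_shape_error = None
except Exception as e:  # exception type may vary by JAX version
    dynamic_shape_error = type(e).__name__

print(immutable_error, dynamic_shape_error)
```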
You can overcome both problems by combining `static_argnums` and `jax.lax.dynamic_update_slice`. Here is an example:
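```python
from functools import partial
import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

# `index` (argument 1) is static, so x[:index] has a concrete shape at trace time
@partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    update = other_fun(x[:index])
    # returns a copy of x with `update` written starting at position 0
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]
```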
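To see the recompilation that `static_argnums` causes, a minimal sketch can count how often `fun` is traced. The `trace_count` counter is a hypothetical addition for illustration: a Python side effect inside a jitted function runs only while JAX traces it, not on cached calls.

```python
from functools import partial
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces `fun`

def other_fun(x):
    return x + 1

@partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
a = fun(x, 3)  # new static value: trace and compile
b = fun(x, 3)  # same static value and input shape: cached, no retrace
c = fun(x, 2)  # new static value: trace and compile again
print(trace_count)  # 2
```

Since each distinct `index` value triggers a new compilation, this approach is best suited to cases where `index` takes only a handful of different values.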
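Essentially, the example above uses `static_argnums` to tell `jax.jit` that the function should be recompiled for each distinct `index` value, so `x[:index]` always has a concrete shape, and `jax.lax.dynamic_update_slice` returns a copy of `x` in which the first `len(update)` entries (`x[:len(update)]`) are replaced by `update`.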
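In the example the update is written at position 0, but the start position passed to `jax.lax.dynamic_update_slice` may itself be a traced value; only the shape of the update has to be known at compile time. A minimal sketch, with `place` as a hypothetical helper name, that also shows the original array is left unchanged:

```python
import jax
import jax.numpy as jnp

@jax.jit
def place(x, update, start):
    # `start` is traced (dynamic); only the shape of `update` must be static.
    return jax.lax.dynamic_update_slice(x, update, (start,))

x = jnp.arange(5)
y = place(x, jnp.array([9, 9]), 1)
print(y)  # [0 9 9 3 4]
print(x)  # [0 1 2 3 4] (x itself is unchanged)
```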