I have tried to run the code. It has a setting called `n_jitted_steps=5`, which, according to the authors, accumulates several steps. Since the code is rather complicated, it is hard to understand. However, I have tried the following in Colab, where the relevant cell is:
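```python
import jax
import optax

# Fails: n_jitted_steps is not a parameter of jax.jit.
@jax.jit(n_jitted_steps=5)
def train_step(state, batch):
    """Train for a single step."""
    def loss_fn(params):
        # Forward pass with the candidate parameters.
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']).mean()
        return loss
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    # Apply one optimizer update and return the new train state.
    state = state.apply_gradients(grads=grads)
    return state
```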
Obviously, this raises an error. However, I wonder:

- Is the purpose of `n_jitted_steps=5` to run five steps in one go, similar to loop unrolling?
- If that is the case, what is the correct way to use it?
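For reference, here is a minimal sketch of what I assume "running several steps in one go" would look like. I am not sure this is what the authors' code does. `jax.jit` itself has no `n_jitted_steps` argument, so the idea would be to stack 5 batches along a new leading axis and let `jax.lax.scan` loop a single-step function over them inside one `jax.jit` call (`lax.scan` also accepts an `unroll` argument if literal loop unrolling is wanted). The toy linear model, `sgd_step`, `multi_step`, `N_JITTED_STEPS` and the learning rate are all hypothetical, and plain SGD stands in for `state.apply_gradients`.

```python
import jax
import jax.numpy as jnp

N_JITTED_STEPS = 5   # hypothetical: optimizer steps fused into one compiled call
LEARNING_RATE = 0.1  # hypothetical SGD learning rate


def loss_fn(params, batch):
    """Mean squared error of a toy linear model (stand-in for the real model)."""
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)


def sgd_step(params, batch):
    """One training step, returning (new_carry, per_step_output) as lax.scan expects."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


@jax.jit
def multi_step(params, stacked_batches):
    """Run sgd_step once per slice of the leading axis, all in one jit call."""
    # lax.scan threads params from step to step and stacks the per-step losses.
    return jax.lax.scan(sgd_step, params, stacked_batches)


# Usage: the data now carries an extra leading axis of size N_JITTED_STEPS.
kx = jax.random.PRNGKey(0)
x = jax.random.normal(kx, (N_JITTED_STEPS, 32, 3))  # 5 batches of 32 examples
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
params, losses = multi_step(params, {"x": x, "y": y})
print(losses.shape)  # (5,)
```

If I read it correctly, the gain is that Python dispatch and host/device round trips happen once per 5 steps instead of once per step, while the model and optimizer stay unchanged.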
Thanks in advance.