
While following the TensorFlow tutorial on how to implement a Transformer model, I had some doubts about the training process.

The train_step function is implemented as follows:

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  # Teacher forcing: the decoder input is the target shifted right,
  # and the loss is computed against the target shifted left.
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  with tf.GradientTape() as tape:
    predictions, _ = transformer([inp, tar_inp],
                                 training=True)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))

We can see that tf.GradientTape() is bound to the name tape in a with statement. This works, but I don't understand how tape can still be used outside the with block:

gradients = tape.gradient(loss, transformer.trainable_variables)

Shouldn't the tape be closed at the end of the with statement? Thank you very much if you can satisfy my curiosity.

I implemented the code from the tutorial and it works as is.
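For reference, here is a minimal sketch of the behaviour I'm asking about, reduced to a scalar variable (my own toy example, not from the tutorial). Exiting the with block apparently only stops the tape from recording new operations; the tape object itself still exists and holds what it recorded, so gradient() can be called on it afterwards:

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x * x  # operations on watched variables are recorded here

# Outside the with block: recording has stopped, but the tape object
# still holds the recorded operations, so gradient() works here.
dy_dx = tape.gradient(y, x)  # d(x^2)/dx at x=3 -> 6.0
```

Note that a non-persistent tape releases its resources after the first gradient() call; to call gradient() more than once, the tape would need to be created with tf.GradientTape(persistent=True).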
