While following the TensorFlow tutorial on how to implement a transformer model, I had some doubts about the training process.
The train_step function is implemented as follows:
    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]

        with tf.GradientTape() as tape:
            predictions, _ = transformer([inp, tar_inp],
                                         training=True)
            loss = loss_function(tar_real, predictions)

        gradients = tape.gradient(loss, transformer.trainable_variables)
        optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

        train_loss(loss)
        train_accuracy(accuracy_function(tar_real, predictions))
We can see that tf.GradientTape() is bound to the name tape in a with statement. This works, but I don't understand how tape can still be called outside the with statement:
gradients = tape.gradient(loss, transformer.trainable_variables)
Shouldn't tape be closed (and therefore unusable) at the end of the with block? Thank you very much if you can satisfy my curiosity.
I implemented the code from the tutorial and it works as is.
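To illustrate what I mean, here is a minimal sketch (not from the tutorial, just a toy example I wrote) showing that the tape object can still be used after the with block exits:

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # Operations on watched variables are recorded here.
    y = x * x

# The with block has exited, yet tape.gradient() still works:
# d(x^2)/dx at x = 3.0 is 6.0
grad = tape.gradient(y, x)
print(grad)  # tf.Tensor(6.0, shape=(), dtype=float32)
```

So exiting the with block apparently only stops the recording rather than destroying the tape, which is exactly the behavior I'd like to understand.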