I'm familiarizing myself with the Trax library for building deep learning models, and one question I can't find an answer to is how to switch from "train" mode to "eval" mode after model training is complete.
Consider this example with a Transformer (modified from https://github.com/google/trax#1-run-a-pre-trained-transformer):
```python
import trax

model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='train')

# do the training
```
After training, how can I switch the model to "predict" mode?
One idea that comes to mind (following the same example) is to train the model, save it, re-initialize it, this time with mode='predict', and then load the weights by running model.init_from_file(file).
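A minimal sketch of that save/re-initialize/load workaround, assuming the trained weights were already written to a Trax checkpoint file. The names `TRANSFORMER_CONFIG` and `load_for_inference` are hypothetical. The only Trax call it relies on is `init_from_file(..., weights_only=True)`, which is the same call the Trax README uses to load a pre-trained Transformer built with `mode='predict'`. Keeping the hyperparameters in one dict guards against the re-initialized model quietly differing from the trained one.

```python
# Hypothetical shared config: every mode ('train', 'eval', 'predict') must be
# built with identical hyperparameters, or the saved weights won't fit.
TRANSFORMER_CONFIG = dict(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048,
)


def load_for_inference(checkpoint_path, model_fn, mode='predict', config=None):
    """Hypothetical helper: rebuild the model in `mode`, then load trained weights.

    `model_fn` is assumed to be a Trax model constructor such as
    trax.models.Transformer. Only the weights are restored
    (weights_only=True), mirroring the Trax README's pre-trained example;
    as I understand it, a 'predict'-mode model keeps decoding caches in its
    state, so its state does not match the one saved during training.
    """
    config = TRANSFORMER_CONFIG if config is None else config
    model = model_fn(mode=mode, **config)
    model.init_from_file(checkpoint_path, weights_only=True)
    return model
```

With Trax installed, usage would look like `predict_model = load_for_inference('/path/to/model.pkl.gz', trax.models.Transformer)`, where the path is whatever checkpoint the training run produced. Training would then use `trax.models.Transformer(mode='train', **TRANSFORMER_CONFIG)` so both models share one configuration. Passing the constructor in, rather than importing Trax inside the helper, also makes it easy to request `mode='eval'` instead.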
Is there a way to do it directly without having to re-init the model?