I want to save the best model during training and then load it for testing, so I used the following method:
```python
def train():
    # training steps …
    if acc > best_acc:
        best_state = model.state_dict()
        best_acc = acc
    return best_state
```
Then, in the main function, I used `model.load_state_dict(best_state)` to restore the model.
However, I found that `best_state` is always identical to the last state reached during training, not the best one. Does anyone know why this happens and how to avoid it?
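A minimal sketch that reproduces the symptom, assuming a toy `nn.Linear` model and a manual in-place update standing in for `optimizer.step()`. The tensors in the dict returned by `model.state_dict()` share storage with the model's parameters, so a dict kept as `best_state` keeps changing as training continues, whereas `copy.deepcopy(model.state_dict())` takes an independent snapshot.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)  # toy model, stands in for the real network

# state_dict() returns tensors that share storage with the live parameters.
aliased_state = model.state_dict()
# copy.deepcopy allocates new storage for every tensor: an independent snapshot.
snapshot_state = copy.deepcopy(model.state_dict())
original_weight = model.weight.detach().clone()

# Simulate an optimizer step: optimizers update parameters in place.
with torch.no_grad():
    model.weight.add_(1.0)

# The aliased dict followed the update; the deep copy kept the old values.
print(torch.equal(aliased_state["weight"], model.weight.detach()))  # True
print(torch.equal(snapshot_state["weight"], original_weight))       # True
```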
By the way, I know I can save the parameters with `torch.save(the_model.state_dict(), PATH)` and load them with `the_model.load_state_dict(torch.load(PATH))`. However, I don't want to write the parameters to a file, since the train and test functions are in the same file.
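The same idea applied to the whole train-then-test flow in one file, as a sketch: the toy data, the `evaluate` helper, and the training hyperparameters are hypothetical stand-ins (accuracy is measured on the training data only for brevity), and the best weights are kept in memory with `copy.deepcopy` instead of being written to disk with `torch.save`.

```python
import copy

import torch
import torch.nn as nn


def evaluate(model, x, y):
    """Hypothetical accuracy metric for a toy binary classifier."""
    with torch.no_grad():
        preds = (model(x).squeeze(1) > 0).float()
    return (preds == y).float().mean().item()


def train(model, x, y, epochs=20):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
    loss_fn = nn.BCEWithLogitsLoss()
    best_acc = -1.0   # so the first epoch always records a snapshot
    best_state = None
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x).squeeze(1), y)
        loss.backward()
        optimizer.step()  # updates the parameters in place
        acc = evaluate(model, x, y)
        if acc > best_acc:
            best_acc = acc
            # Deep-copy so later optimizer steps cannot overwrite the snapshot.
            best_state = copy.deepcopy(model.state_dict())
    return best_state, best_acc


torch.manual_seed(0)
x = torch.randn(64, 2)
y = (x[:, 0] > 0).float()  # hypothetical, linearly separable labels
model = nn.Linear(2, 1)

best_state, best_acc = train(model, x, y)
model.load_state_dict(best_state)  # restore the best weights from memory, no file
print(evaluate(model, x, y) == best_acc)  # True
```

Cloning each tensor (`{k: v.detach().clone() for k, v in model.state_dict().items()}`) is another common way to take the snapshot; in either case the copy has to be made at the moment the new best accuracy is recorded, and it costs one extra copy of the parameters in memory (on the same device as the model).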