
I have a model architecture. I have saved the entire model using torch.save() for some n number of iterations. I want to run another iteration of my code by using the pre-trained weights of the model I saved previously.

Edit: I want the weight initialization for the new iteration to be done from the weights of the pretrained model.

Edit 2: Just to add, I don't plan to resume training. I intend to save the model and use it for a separate training run with the same parameters. Think of it like using a saved model, with its weights etc., for a larger run with more samples (i.e. a completely new training job).

Right now, I do something like:

# default_lr = 5
# default_weight_decay = 0.001
# model_io = path to the pretrained model
model = torch.load(model_io)
optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
loss_new = torch.nn.BCELoss()
epochs = default_epoch

# ... training loop ...
outputs = model(input)
# ...
# similarly for the test loop

Am I missing something? I have to train for many epochs on a huge number of samples, so I cannot afford to wait for the results and only then figure things out.

Thank you!

Sulphur

1 Answer


From the code you have posted, I see that you are only loading the previous model parameters in order to restart your training from where you left off. This is not sufficient to restart training correctly. Along with your model parameters (weights), you also need to save and load your optimizer state, especially when your optimizer is Adam, which keeps running moment estimates for all your weights that adapt the effective learning rate per parameter.

In order to smoothly restart training, I would do the following:

# For saving your model

state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}
model_save_path = "Enter/your/model/path/here/model_name.pth"
torch.save(state, model_save_path)

# ------------------------------------------

# For loading your model
state = torch.load(model_save_path)

model = MyNetwork()  # re-instantiate your architecture first
model.load_state_dict(state['model'])

optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
optim.load_state_dict(state['optimizer'])

Besides these, you may also want to save your learning rate if you are using a learning-rate decay strategy, your best validation accuracy so far (which you may want for checkpointing purposes), and any other changeable parameter that might affect your training. In most cases, though, saving and loading just the model weights and optimizer state should be sufficient.
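The extended checkpoint described above can be sketched as follows; this is a minimal example with a toy network, and names like `best_val_acc` and `checkpoint.pth` are illustrative, not required:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                      # stand-in for your real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),     # LR-decay state, if you use one
    'best_val_acc': 0.87,                    # example checkpointing metric
    'epoch': 5,                              # where the previous run stopped
}
torch.save(state, 'checkpoint.pth')

# On the next run, restore everything before training continues
state = torch.load('checkpoint.pth')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
start_epoch = state['epoch'] + 1
```

Because everything lives in one dictionary, you can add or drop entries (e.g. a random-seed state) without changing the save/load call sites.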

EDIT: You may also want to look at the following answer, which explains in detail how you should save your model in different scenarios.

ntd
  • Thank you for the comment. I don't intend to change my optimizer and learning rate for this new iteration. In short, all parameters stay the same. I would just want to ensure it starts off from the saved weights. It was my understanding that torch.save would save the entire model, so simply using that model would enable me to use the weights as the starting point. Is my understanding incorrect? – Sulphur Jun 02 '20 at 03:46
  • Yes, that's correct. I have also added another link to my original answer which may help you further. – ntd Jun 02 '20 at 04:31
  • Thank you for the link. From Jadiel de Armas's answer (towards the bottom) and Case #3 of his answer, will my approach be correct if all my parameters are the same? i.e. given my parameters are the same, I assume `torch.save` would save everything including the architecture, weights etc., and I can simply load the model as model = load(.pth model) and run the training? I am not using any learning-rate decay strategy either. – Sulphur Jun 02 '20 at 07:01
  • Just to add, I don't plan to resume training. I intend to save the model and use it for a separate training with same parameters. Think of it like using a saved model with weights etc for a larger run and more samples. – Sulphur Jun 02 '20 at 07:08
  • Yes, `torch.save()` would save your entire model architecture and the weights of your model, and you should be able to load your model using `torch.load()`; just make sure that you do not change your network architecture code significantly, as mentioned in the following git repo: https://github.com/pytorch/pytorch/blob/761d6799beb3afa03657a71776412a2171ee7533/docs/source/notes/serialization.rst – ntd Jun 02 '20 at 08:23
  • Thanks! No, I have 0 changes except for number of samples I am using. – Sulphur Jun 02 '20 at 16:27
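To make the whole-model route from the comments concrete, here is a minimal sketch; the file name and toy module are illustrative, and note that newer PyTorch releases default `torch.load` to `weights_only=True`, so unpickling a full module needs `weights_only=False`:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # stand-in for the real architecture
torch.save(model, 'whole_model.pth')          # pickles the module: class reference + weights

# In the new training job there is no need to re-instantiate the class yourself,
# but the class definition must still be importable and unchanged.
loaded = torch.load('whole_model.pth', weights_only=False)
optim = torch.optim.Adam(loaded.parameters(), lr=5, weight_decay=0.001)
```

The trade-off is that whole-model pickles are tied to the exact class/module paths at save time, which is why the state_dict approach in the answer is generally more robust across code changes.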