
I want to save the best model and then load it during the test. So I used the following method:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  

Then, in the main function I used:

model.load_state_dict(best_state)  

to resume the model.

However, I found that best_state is always the same as the last state during training, not the best state. Does anyone know the reason and how to avoid it?

By the way, I know I can use torch.save(the_model.state_dict(), PATH) and then load the model with the_model.load_state_dict(torch.load(PATH)). However, I don't want to save the parameters to a file, since the train and test functions are in the same file.

prosti
Lei_Bai

2 Answers


model.state_dict() returns an OrderedDict:

from collections import OrderedDict

You can use:

import copy

To fix the problem, instead of:

best_state = model.state_dict()

you should use:

best_state = copy.deepcopy(model.state_dict())

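model.state_dict() returns references to the model's parameter tensors rather than copies, and a shallow copy of the OrderedDict would still share those tensors. The optimizer updates the parameters in place, so best_state keeps changing along with the model and ends up equal to the last state. A deep copy clones the tensors, so best_state stays frozen at the moment it was taken.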
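The difference can be checked with a small experiment. This is only a minimal sketch: the nn.Linear model, the SGD optimizer, the dummy input, and the variable names are assumptions made for illustration, not the asker's actual training code.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in model and optimizer (assumed for illustration only)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Values in this dict share storage with the live parameters
aliased_state = model.state_dict()
# Values in this dict are independent clones
copied_state = copy.deepcopy(model.state_dict())
# Separate record of the weights before training, used for comparison
weight_before = model.weight.detach().clone()

# One training step on dummy data; the optimizer updates parameters in place
loss = model(torch.ones(4, 2)).sum()
loss.backward()
optimizer.step()

model_changed = not torch.equal(model.weight, weight_before)
aliased_follows_model = torch.equal(aliased_state["weight"], model.weight)
copy_kept_snapshot = torch.equal(copied_state["weight"], weight_before)

print(model_changed, aliased_follows_model, copy_kept_snapshot)
```

If the explanation above holds, the plain state_dict() result follows the model after the step, while the deep copy still matches the weights recorded before it.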

You may check my other answer on saving the state dict in PyTorch.

prosti

When you save the state of the model, you should save two things:

1) the optimizer's state dict and 2) the model's state dict

You can define a helper function as follows:

import torch

def save_state(state, filename):
    # Serialize the checkpoint dictionary to disk
    torch.save(state, filename)

# When saving, pass both state dicts and a file name
# (model and optimizer are your existing objects):
save_state({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'model.pth.tar')

The saved model will be stored as model.pth.tar (for example).

To load it, do the following:

checkpoint = torch.load('model.pth.tar')         

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

Hope this will help you.

Saurav Rai
  • This answer is good fit for [this question](https://stackoverflow.com/q/42703500/5884955), but not perfect fit for this one. – prosti Jun 12 '19 at 14:19
  • Thanks, I know I can save the state to a file. But I prefer Saurav's answer – Lei_Bai Jun 13 '19 at 11:02
  • @Lei_Bai Very happy to help you in a small way. Hopefully, I will be of your help in the future as well. – Saurav Rai Jun 13 '19 at 11:19