I have developed an image classification model using the PyTorch framework. I saved checkpoints during training with the following code:

import torch

def save_checkpoint(model, epoch, optimizer, loss, path):
    """
    :param model: model to be saved
    :param epoch: epoch at which the model gets saved
    :param optimizer: optimizer used to update the model parameters
    :param loss: loss value at the time of saving
    :param path: path for the checkpoint to be saved
    :return: None
    """

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
                }, path)
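For context, here is a minimal sketch of how `save_checkpoint` might be called once per epoch. It assumes a toy `nn.Linear` model and random data in place of the real ResNet-50 and dataset, and a temporary directory in place of `Checkpoint/`. Note that it stores the loss *value* (`loss.item()`), not the loss function:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(model, epoch, optimizer, loss, path):
    # Same fields as in the question: epoch, both state dicts, and the loss value.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss}, path)


torch.manual_seed(0)
model = nn.Linear(4, 3)  # toy stand-in (assumption) for the ResNet-50 classifier
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 4)               # random toy batch (assumption)
targets = torch.randint(0, 3, (8,))

ckpt_dir = tempfile.mkdtemp()
for epoch in range(3):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Save a plain float loss so the checkpoint holds no autograd graph.
    path = os.path.join(ckpt_dir, f'model_epoch_{epoch}.pth')
    save_checkpoint(model, epoch, optimizer, loss.item(), path)

# Inspect the last checkpoint written.
checkpoint = torch.load(path, map_location='cpu')
print(sorted(checkpoint.keys()))
print(checkpoint['epoch'])
```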

I load the .pth checkpoint file with the following code:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

def load_checkpoint(path, num_classes):
    checkpoint = torch.load(path)
    model = models.resnet50(pretrained=True)
    n_features = model.fc.in_features
    model.fc = nn.Linear(n_features, num_classes)
    # learning_rate, beta1, beta2 and eps are assumed to be defined globally in the training script
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2), eps=eps)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
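The last line, `model.train()`, only switches the module and its submodules (for example dropout and batch-norm layers) into training mode; it does not run any optimization steps. A quick check with a toy model (an assumption, standing in for ResNet-50) shows that the flag changes but the weights do not:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model with a dropout layer (assumption, in place of ResNet-50).
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 3))
model.eval()
before = [p.detach().clone() for p in model.parameters()]

model.train()  # only sets the `training` flag on this module and all children

print(model.training)                                # True
print(all(m.training for m in model.modules()))      # True
# No weights changed: no forward/backward pass or optimizer.step() happened.
print(all(torch.equal(a, b) for a, b in zip(before, model.parameters())))  # True
```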

However, when I run the loading code, training does not resume:

load_checkpoint('Checkpoint/model_epoch_2.pth', num_classes=3)

Could you please advise?

Saksham
  • What's the error? – kkgarg Dec 20 '21 at 15:15
  • There is no error at all. There is nothing in the console that says training has resumed. – Saksham Dec 20 '21 at 15:18
  • "There is nothing in the console that says training has resumed." That is because (1) model.train() does not start training the model; see https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch/51433411#51433411, and (2) your code never prints that training has resumed, so nothing says it has. – Umang Gupta Dec 20 '21 at 18:15
  • Also, you're creating a model and optimizer and loading the saved state dicts into them, but you never return them, so you won't have access to the loaded model or optimizer outside of the load_checkpoint function. You should either pass the existing model and optimizer into your load function and load the state dicts into those, or return the new model and optimizer to the caller (you could also use the global keyword, but this is generally considered poor design unless there is a good reason). – jodag Dec 21 '21 at 00:12
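Following jodag's first suggestion (pass the existing model and optimizer into the load function), a minimal sketch of resuming might look like this. It assumes a toy `nn.Linear` model, random data, and a hypothetical `train_one_epoch` helper in place of the real ResNet-50 setup. The key points are that `load_checkpoint` returns the epoch to resume from, and that the weights only change inside the loop that calls `optimizer.step()`:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def load_checkpoint(path, model, optimizer):
    """Load saved state into an existing model/optimizer; return (next epoch, last loss)."""
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1, checkpoint['loss']


def train_one_epoch(model, optimizer, criterion, inputs, targets):
    """Hypothetical helper: a single optimization step stands in for one epoch."""
    model.train()                  # sets training mode only
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()               # this is what actually updates the weights
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(8, 4), torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()

# Simulate an earlier run that saved a checkpoint after epoch 2.
old_model = nn.Linear(4, 3)
old_opt = optim.Adam(old_model.parameters(), lr=1e-3)
for epoch in range(3):
    last_loss = train_one_epoch(old_model, old_opt, criterion, inputs, targets)
path = os.path.join(tempfile.mkdtemp(), 'model_epoch_2.pth')
torch.save({'epoch': epoch,
            'model_state_dict': old_model.state_dict(),
            'optimizer_state_dict': old_opt.state_dict(),
            'loss': last_loss}, path)

# New run: build the model and optimizer first, then load the saved state into them.
model = nn.Linear(4, 3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
start_epoch, saved_loss = load_checkpoint(path, model, optimizer)
loaded_matches = all(torch.equal(a, b)
                     for a, b in zip(model.parameters(), old_model.parameters()))
print(f'Resuming training from epoch {start_epoch} (last loss {saved_loss:.4f})')

num_epochs = 5
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer, criterion, inputs, targets)
    print(f'epoch {epoch}: loss {loss:.4f}')
```

The explicit `print` is what makes the resumption visible in the console; loading state dicts and calling `model.train()` are silent on their own.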
