I have developed an image classification model using the PyTorch framework. During training, I saved checkpoints with the following code:
def save_checkpoint(model, epoch, optimizer, loss, path):
    """
    :param model: model to be saved
    :param epoch: epoch at which the model gets saved
    :param optimizer: optimizer to compute the gradients
    :param loss: loss function for the model
    :param path: path for the checkpoint to be saved
    :return: None
    """
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
                }, path)
I then load a .pth checkpoint file with the following code:
def load_checkpoint(path, num_classes):
    checkpoint = torch.load(path)
    model = models.resnet50(pretrained=True)
    n_features = model.fc.in_features
    model.fc = nn.Linear(n_features, num_classes)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2), eps=eps)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
However, when I call the loading function as follows, training does not resume:
load_checkpoint('Checkpoint/model_epoch_2.pth', num_classes=3)
Could you kindly advise how to resume training from the checkpoint?
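
For context, here is a minimal sketch of the behaviour being asked about. It is built on assumptions, not on the real training script. Note that `load_checkpoint` as posted only restores state and returns nothing, and no training loop is called after it. The sketch therefore assumes a variant of `load_checkpoint` that takes an existing model and optimizer and returns the saved epoch and loss, plus a hypothetical `train` loop that starts from that epoch. A tiny `nn.Linear` model, random data, and a temporary directory stand in for ResNet-50, the real dataset, and the `Checkpoint/` folder.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(model, epoch, optimizer, loss, path):
    # Same checkpoint layout as in the question.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss}, path)


def load_checkpoint(path, model, optimizer):
    # Assumed variant: restore state in place and RETURN what the loop needs.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    return checkpoint['epoch'], checkpoint['loss']


def train(model, optimizer, criterion, data, start_epoch, num_epochs, ckpt_dir):
    # Hypothetical training loop; starts at start_epoch instead of always 0.
    for epoch in range(start_epoch, num_epochs):
        for inputs, targets in data:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        save_checkpoint(model, epoch, optimizer, loss.item(),
                        os.path.join(ckpt_dir, f'model_epoch_{epoch}.pth'))


def build():
    # Stand-in for the ResNet-50 setup: 4 input features, 3 classes.
    model = nn.Linear(4, 3)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    return model, optimizer


torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
data = [(torch.randn(8, 4), torch.randint(0, 3, (8,)))]  # one random batch
ckpt_dir = tempfile.mkdtemp()

# First run: epochs 0, 1, 2 are trained and checkpointed.
model, optimizer = build()
train(model, optimizer, criterion, data, 0, 3, ckpt_dir)

# Second run: rebuild the objects, restore them, then continue from epoch 3.
model, optimizer = build()
last_epoch, last_loss = load_checkpoint(
    os.path.join(ckpt_dir, 'model_epoch_2.pth'), model, optimizer)
start_epoch = last_epoch + 1
train(model, optimizer, criterion, data, start_epoch, 5, ckpt_dir)
```

In this sketch, training only continues because the restored model and optimizer are handed back to the caller and the loop is invoked again with `start_epoch`. Whether this matches the intended setup depends on how the original training loop is written.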