
I'm finding that after I save and later load my PyTorch model, performance drops substantially on both the train and test sets. I'm currently training the model on CIFAR-10. Below is the code I run to save and then load the model.

Save:

if save_model:
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, state_path)
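
(The model is wrapped in nn.DataParallel before loading, so I assume the checkpoint also stores keys with a 'module.' prefix. One variant I've been considering, sketched below with the same variable names as above, saves the unwrapped weights instead so the checkpoint does not depend on the wrapper; I haven't confirmed whether this is related to the problem.)

# Sketch only: save the underlying module's weights so the checkpoint
# does not depend on the nn.DataParallel wrapper.
if save_model:
    raw_model = model.module if isinstance(model, nn.DataParallel) else model
    torch.save({
        'epoch': epoch,
        'model_state_dict': raw_model.state_dict(),  # keys without the 'module.' prefix
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, state_path)

If the checkpoint were saved this way, the load side would need to call load_state_dict on model.module (or load the weights before wrapping the model in nn.DataParallel).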

Load:

model = ViT_model.ViT(image_size=image_size, patch_size=patch_size, num_classes=10,
                      dim=dim, depth=numblocks, mlp_dim=dim, attention_type='multi_head_q',
                      heads=heads, dropout=dropout, emb_dropout=dropout,
                      fixed_size=False, pre_layers=pre_layers)
model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=initial_lr, betas=(0.9, 0.99), weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
if load_model:
    checkpoint = torch.load(state_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    starting_epoch = checkpoint['epoch'] + 1
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print(f'Loaded model at epoch {starting_epoch}')

Accuracy immediately drops on the loaded model and never really recovers.
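
For reference, a minimal round-trip check along these lines (a sketch reusing model and device from the snippets above; the file name is just a placeholder) should show whether the weight tensors themselves survive the save/load, independent of the optimizer and scheduler state:

# Sketch of a round-trip check; expects the variables defined above.
check_path = 'roundtrip_check.pth'  # placeholder path, to avoid touching the real checkpoint
before = {k: v.detach().clone() for k, v in model.state_dict().items()}
torch.save({'model_state_dict': model.state_dict()}, check_path)

checkpoint = torch.load(check_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

after = model.state_dict()
mismatched = [k for k in before if not torch.equal(before[k], after[k])]
print(f'{len(mismatched)} tensors differ after reload')  # expect 0 if serialization is intact

If that comes back clean, the drop would presumably have to come from something outside the checkpoint itself (e.g. train/eval mode, data preprocessing, or how the optimizer and scheduler resume).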
