
I am training my network with an early stopping strategy. I start with a higher learning rate, and, based on the validation loss, I need to restart training from an earlier snapshot at a lower learning rate.

I am able to save/load snapshot with model and optimizer state_dicts. No problem with that.

My question is: once I restart training, how do I set Adam's learning rate again? Should I reinitialize Adam from scratch instead of loading its state_dict, or should I load the optimizer state_dict and then set `optimizer.param_groups[0]['lr'] = lr` to adjust the learning rate?

For example, I trained my network with lr = 1e-6 for 5 epochs and saved the model and optimizer state_dicts. I am now restarting from epoch 6, but I need lr = 1e-7 instead. What is the best approach for this?
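
For reference, this is roughly my save/restore code (a simplified sketch with a toy model; `checkpoint.pt` and the exact numbers are just placeholders):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # stand-in for my actual network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

# ... train for 5 epochs, then snapshot ...
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
           'checkpoint.pt')

# ... later, restarting from epoch 6 ...
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# This is where I am unsure: recreate Adam with lr = 1e-7,
# or keep the loaded state and just overwrite the lr?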

Thanks!


2 Answers


Looking at PyTorch's torch.optim.lr_scheduler code here, I can see that the schedulers set the learning rate directly on the optimizer's param groups, so that is the best approach. The exact place you can see this is the step function of the _LRScheduler class (in the above link).

You can do the same with

optimizer.param_groups[0]['lr'] = lr

as you mentioned yourself.
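
For your resume scenario, a minimal sketch of that approach (`new_lr` is just a placeholder name; this assumes the optimizer state_dict has already been loaded from your checkpoint, so Adam's moment estimates are kept, and it loops over all param groups in case there is more than one):

# Assumes optimizer.load_state_dict(checkpoint['optimizer_state_dict']) has
# already been called, so Adam's moment estimates are preserved.
new_lr = 1e-7  # the rate you want for the resumed run
for param_group in optimizer.param_groups:
    param_group['lr'] = new_lr  # only the lr changes; other settings stay as loaded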

– akshayk07
  • I dived further into the code, and I think I figured out the correct way to do this. Please check my answer. – Xoul Oct 09 '19 at 11:09
  • No need to do all that, that single line `optimizer.param_groups[0]['lr'] = lr` is enough at the point where you need to change the learning rate. They have done it to formally define it in a class. – akshayk07 Oct 09 '19 at 11:15
  • Interesting. Does that work even if there are multiple parameters set when initializing optimizer? – Xoul Oct 10 '19 at 13:11
  • Yes, I think so. It should also work for any other parameters (like weight decay or momentum). Only the parameter you change will change; the others remain the same. – akshayk07 Oct 10 '19 at 13:27

Looking further into the scheduler code, I found that the correct way to do it is:

def get_lr(gamma, optimizer):
    # Scale the current learning rate of every param group by gamma.
    return [group['lr'] * gamma
            for group in optimizer.param_groups]

# Write the scaled rates back into the optimizer's param groups.
for param_group, lr in zip(optimizer.param_groups, get_lr(gamma, optimizer)):
    param_group['lr'] = lr
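
For the numbers in the question (going from lr = 1e-6 to 1e-7 when resuming at epoch 6), gamma would be 0.1, defined before the loop above runs:

gamma = 0.1  # 1e-6 * 0.1 == 1e-7 for the run resumed at epoch 6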
– Xoul