Loss functions in PyTorch use "mean" reduction by default, so the model's gradient has roughly the same magnitude for any batch size. It makes sense that you would scale the learning rate up when you increase the batch size, because the gradient itself doesn't get bigger as the batch grows (it only gets less noisy).
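To illustrate what I mean, here is a minimal sketch with a toy `nn.Linear` model and random data (everything in it is made up for illustration): with the default `reduction="mean"`, the gradient norm stays in the same ballpark as the batch size grows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss(reduction="mean")  # "mean" is PyTorch's default reduction

for batch_size in (8, 64, 512):
    model.zero_grad()
    x = torch.randn(batch_size, 10)
    y = torch.randn(batch_size, 1)
    loss_fn(model(x), y).backward()
    print(f"batch_size={batch_size:4d}  grad_norm={model.weight.grad.norm().item():.4f}")

# The printed norms are all of a similar order of magnitude; with
# reduction="sum" they would instead grow roughly linearly with batch size.
```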
For gradient accumulation in PyTorch, the gradients are summed over the N backward() calls you make before calling step(). My intuition is that this increases the magnitude of the gradient, so you should reduce the learning rate, or at least not increase it.
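Here is another minimal sketch of the same kind (toy model, random data) showing that calling backward() several times without zeroing the gradients simply sums them:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
x = torch.randn(32, 10)
y = torch.randn(32, 1)

# One backward pass on the batch.
model.zero_grad()
loss_fn(model(x), y).backward()
single = model.weight.grad.clone()

# Same batch, backward() called N=4 times without zeroing in between.
N = 4
model.zero_grad()
for _ in range(N):
    loss_fn(model(x), y).backward()
accumulated = model.weight.grad.clone()

print(torch.allclose(accumulated, N * single))  # True: the gradients were summed
```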
But I have seen people multiply the learning rate by the number of gradient accumulation steps, for example in this repo:
if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )
I also see similar code in this repo:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
I understand why you would scale the learning rate with the batch size, but I don't understand why they also scale it by the number of accumulation steps.
- Do they divide the loss by N to reduce the magnitude of the gradient? Otherwise, why do they multiply the learning rate by the number of accumulation steps? (See the sketch after these questions for what I mean.)
- How are gradients from different GPUs combined? Is it a mean or a sum? If it's a sum, why do they also multiply the learning rate by the number of GPUs?
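To make the first question concrete, here is a hypothetical sketch (toy model, random data, not taken from either repo): if each micro-batch loss is divided by N before backward(), the accumulated gradient matches the one from a single mean-reduced pass over the full batch, so the magnitudes would line up again.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss(reduction="mean")
x = torch.randn(64, 10)
y = torch.randn(64, 1)

# Reference: one backward pass over the full batch.
model.zero_grad()
loss_fn(model(x), y).backward()
full_batch_grad = model.weight.grad.clone()

# Gradient accumulation over N micro-batches, dividing each loss by N.
N = 4
model.zero_grad()
for xb, yb in zip(x.chunk(N), y.chunk(N)):
    (loss_fn(model(xb), yb) / N).backward()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```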