
I'm trying to get a better understanding of how Gradient Accumulation works and why it is useful. To this end, I wanted to ask what is the difference (if any) between these two possible PyTorch-like implementations of a custom training loop with gradient accumulation:

First implementation:

gradient_accumulation_steps = 5
for batch_idx, batch in enumerate(dataset):
  x_batch, y_true_batch = batch
  y_pred_batch = model(x_batch)

  loss = loss_fn(y_true_batch, y_pred_batch)
  loss.backward()

  if (batch_idx + 1) % gradient_accumulation_steps == 0: # (assumption: the number of batches is a multiple of gradient_accumulation_steps)
    optimizer.step()
    optimizer.zero_grad()

Second implementation:

y_true_batches, y_pred_batches = [], []
gradient_accumulation_steps = 5
for batch_idx, batch in enumerate(dataset):
  x_batch, y_true_batch = batch
  y_pred_batch = model(x_batch)

  y_true_batches.append(y_true_batch)
  y_pred_batches.append(y_pred_batch)

  if (batch_idx + 1) % gradient_accumulation_steps == 0: # (assumption: the number of batches is a multiple of gradient_accumulation_steps)
    y_true = stack_vertically(y_true_batches)  # e.g. torch.cat(y_true_batches, dim=0)
    y_pred = stack_vertically(y_pred_batches)  # e.g. torch.cat(y_pred_batches, dim=0)

    loss = loss_fn(y_true, y_pred)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    y_true_batches.clear()
    y_pred_batches.clear()

Also, as a somewhat unrelated question: since the purpose of gradient accumulation is to mimic a larger batch size in cases where you have memory constraints, does that mean I should also increase the learning rate proportionally?


2 Answers


1. The difference between the two programs:
Conceptually, your two implementations are the same: you forward gradient_accumulation_steps batches for each weight update.
As you already observed, the second method requires more memory than the first one, since it has to keep the activations (computation graphs) of all the accumulated mini-batches alive until the loss is computed.

There is, however, a slight difference: loss-function implementations usually use the mean to reduce the loss over the batch. With gradient accumulation (first implementation) you reduce with the mean over each mini-batch, but effectively with a sum over the gradient_accumulation_steps accumulated mini-batches. To make the gradient-accumulation implementation identical to the large-batch implementation, you need to be careful about how the loss is reduced: in many cases you will need to divide each mini-batch loss by gradient_accumulation_steps before calling backward(). See this answer for a detailed implementation.
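
For concreteness, here is a minimal sketch of the first implementation with the loss rescaled (my sketch, not the linked answer; it assumes loss_fn reduces with the mean and that all mini-batches have the same size):

gradient_accumulation_steps = 5
optimizer.zero_grad()
for batch_idx, batch in enumerate(dataset):
  x_batch, y_true_batch = batch
  y_pred_batch = model(x_batch)

  # divide the per-mini-batch mean loss so that summing the gradients over
  # gradient_accumulation_steps mini-batches reproduces the large-batch mean
  loss = loss_fn(y_true_batch, y_pred_batch) / gradient_accumulation_steps
  loss.backward()

  if (batch_idx + 1) % gradient_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()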


2. Batch size and learning rate: Learning rate and batch size are indeed related. When increasing the batch size one usually reduces the learning rate.
See, e.g.:
Samuel L. Smith, Pieter-Jan Kindermans, Chris Ying, Quoc V. Le, Don't Decay the Learning Rate, Increase the Batch Size (ICLR 2018).
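
For concreteness, here is a rough sketch of how one might adjust the learning rate when gradient accumulation changes the effective batch size. The exact relationship is debated (the comments below discuss linear vs. sqrt scaling); this sketch assumes the linear heuristic, and the base values are made up:

import torch

base_lr = 1e-3                   # learning rate tuned for the base batch size
base_batch_size = 32

micro_batch_size = 32
gradient_accumulation_steps = 5
effective_batch_size = micro_batch_size * gradient_accumulation_steps  # 160

scale = effective_batch_size / base_batch_size
lr = base_lr * scale             # linear scaling
# lr = base_lr * scale ** 0.5    # sqrt scaling, the other convention mentioned in the comments

optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # 'model' assumed to be defined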

  • Part 1: thanks for the detailed explanation. Part 2: doesn't the paper imply that *if* we're thinking of decreasing the learning rate, we should consider increasing the batch size instead? Based on what I've read about the topic, the rule of thumb is that if you increase the batch size by a factor of N, then you should increase the learning rate by a factor of either N or sqrt(N) (there's no general consensus) – Francesco Cariaggi Jan 03 '22 at 08:40
  • @FrancescoCariaggi this paper indeed claims that one should change the batch size rather than the learning rate. But the "take-home" message is a linear relation between the two. As you mentioned, there are those who consider a `sqrt` relation. I don't think there's a definitive preferred course of action. – Shai Jan 03 '22 at 08:44
  • You should use a `sum` loss instead of `mean` when training a language model with a gradient accumulation step larger than 1, because the number of tokens varies among mini-batches; therefore the "mean over iterations of each iteration's mean token loss" is not equivalent to the "mean loss over all tokens in all iterations". – Yang Bo Jul 10 '23 at 02:02

Gradient accumulation can be useful when you want to train on a GPU but your GPU runs out of memory. With gradient accumulation you don't update the weights at every mini-batch, but only at a certain frequency (every N mini-batches). There is an interesting demo on exactly this topic in Jeremy Howard's fastai course, Lecture 6. Hope this helps.
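
For example (a small illustration with made-up numbers, not from the fastai lecture): if each mini-batch that fits on the GPU has 16 samples and you accumulate over 4 of them, the weights are updated once every 4 mini-batches, with an effective batch size of 64:

micro_batch_size = 16            # what fits in GPU memory per forward/backward pass
accumulation_steps = 4           # update frequency: one optimizer step every 4 mini-batches
effective_batch_size = micro_batch_size * accumulation_steps   # 64 samples per weight update

num_batches_per_epoch = 1000
updates_per_epoch = num_batches_per_epoch // accumulation_steps  # 250 weight updates per epoch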