
I'm trying to train a model in PyTorch, and I'd like to have a batch size of 8, but due to memory limitations, I can only have a batch size of at most 4. I've looked all around and read a lot about accumulating gradients, and it seems like the solution to my problem.

However, I seem to have trouble implementing it. Every time I run the code I get RuntimeError: Trying to backward through the graph a second time. I don't understand why since my code looks like all these other examples I've seen (unless I'm just missing something major):

  1. https://stackoverflow.com/a/62076913/1227353
  2. https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
  3. https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20

One caveat is that the labels for my images are all different sizes, so I can't send the output batch and the label batch into the loss function; I have to iterate over them together. This is what an epoch looks like (it's been pared down for the sake of brevity):

  # labels_batch contains labels of different sizes
  for batch_idx, (inputs_batch, labels_batch) in enumerate(dataloader):
    outputs_batch = model(inputs_batch)

    # have to do this because labels can't be stacked into a tensor
    for output, label in zip(outputs_batch, labels_batch):
      output_scaled = interpolate(...)  # make output match label size
      loss = train_criterion(output_scaled, label) / (BATCH_SIZE * 2)
      loss.backward()

    if batch_idx % 2 == 1:
      optimizer.step()
      optimizer.zero_grad()

Is there something I'm missing? If I do the following I also get an error:

  # labels_batch contains labels of different sizes
  for batch_idx, (inputs_batch, labels_batch) in enumerate(dataloader):
    outputs_batch = model(inputs_batch)

    # CHANGE: we're gonna accumulate losses manually
    batch_loss = 0

    # have to do this because labels can't be stacked into a tensor
    for output, label in zip(outputs_batch, labels_batch):
      output_scaled = interpolate(...)  # make output match label size
      loss = train_criterion(output_scaled, label) / (BATCH_SIZE * 2)
      batch_loss += loss # CHANGE: accumulate!

    # CHANGE: do backprop outside for loop
    batch_loss.backward()

    if batch_idx % 2 == 1:
      optimizer.step()
      optimizer.zero_grad()

The error I get in this case is RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. This happens when the next epoch starts though... (INCORRECT, SEE EDIT BELOW)

How can I train my model with gradient accumulation? Or am I doomed to train with a batch size of 4 or less?

Oh and as a side question, does the location of where I put loss.backward() affect what I need to normalize the loss by? Or is it always normalized by BATCH_SIZE * 2?

EDIT:

The second code segment was getting an error due to the fact that I was doing torch.set_grad_enabled(phase == 'train') but I had forgotten to wrap the call to batch_loss.backward() with an if phase == 'train'... my bad

So now the second segment of code seems to work and do gradient accumulation, but why doesn't the first bit of code work? It feels equivalent to setting BATCH_SIZE to 1. Furthermore, I'm creating a new loss object each time, so shouldn't the calls to backward() operate on different graphs entirely?

rcplusplus
  • If you call `.backward` twice on the same graph, or part of the same graph you will get "Trying to backward through the graph a second time". But you could accumulate the loss in a tensor and then, only when you're done call `.backward` on it. In your last example, you may try to call `batch_loss.backward()` (instead of `loss.backward()`) after the loop on the *zip* has ended. – Ivan Dec 13 '20 at 09:02
  • @Ivan whoops that was a typo in my post, I was actually calling backward on batch loss there and getting the error i mentioned – rcplusplus Dec 13 '20 at 16:40
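A minimal sketch reproducing the error discussed in these comments, assuming a toy `nn.Linear` model and random tensors in place of the real model and data: one forward pass over the whole batch followed by one `backward()` per sample fails from the second sample on, while summing the per-sample losses and calling `backward()` once works.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(10, 3)          # toy stand-in for the real model (assumption)
inputs_batch = torch.rand(4, 10)  # random stand-in data (assumption)
labels_batch = torch.rand(4, 3)

# One forward pass for the whole batch: every output row hangs off this one graph
outputs_batch = model(inputs_batch)

errors = []
for output, label in zip(outputs_batch, labels_batch):
    loss = F.l1_loss(output, label)
    try:
        loss.backward()  # the first call frees the shared graph's saved tensors
    except RuntimeError as e:
        errors.append(str(e))

print(len(errors))  # 3: every call after the first one fails

# Working variant: accumulate the losses, then call backward() a single time
model.zero_grad()
outputs_batch = model(inputs_batch)
batch_loss = sum(F.l1_loss(o, l) for o, l in zip(outputs_batch, labels_batch))
batch_loss.backward()
print(model.weight.grad is not None)  # True
```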


It seems you have two issues here: you said you couldn't have batch_size=8 because of memory limitations, but later state that your labels are not all the same size. The latter seems much more important than the former. Anyway, I will try to answer your questions as best I can.


How can I train my model with gradient accumulation? Or am I doomed to train with a batch size of 4 or less?

You want to call .backward() on every loop cycle, otherwise the batch will have no effect on the training. You can then call step() and zero_grad() only when batch_idx % 2 is nonzero (i.e. on every other batch).

Here's an example which accumulates the gradient, not the loss:

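import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 3)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

ds = TensorDataset(torch.rand(100, 10), torch.rand(100, 3))
dl = DataLoader(ds, batch_size=4)

for i, (x, y) in enumerate(dl):
    y_hat = model(x)
    # Divide by the number of accumulation steps (2) so the summed gradient is an average
    loss = F.l1_loss(y_hat, y) / 2
    loss.backward()  # adds this batch's gradient into each parameter's .grad

    # Update only on every other batch (i = 1, 3, 5, ...), i.e. every 8 samples
    if i % 2:
        optim.step()
        optim.zero_grad()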

Note this approach is different from accumulating the loss and back-propagating only once all batches (or a subset of them) have gone through the network. In the example above, we backpropagate every 4 data points and update the model every 8 data points.
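To make the difference concrete, here is a small sketch (toy `nn.Linear` model, random data, and two micro-batches of 4 standing in for one batch of 8, all assumptions) that computes the gradient both ways. Both end up with the same `.grad`; the difference is memory: gradient accumulation frees each micro-batch's graph right after its `backward()`, whereas loss accumulation keeps every graph alive until the single `backward()` at the end.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model_a = nn.Linear(10, 3)
model_b = copy.deepcopy(model_a)  # identical starting weights

x = torch.rand(8, 10)
y = torch.rand(8, 3)
micro_batches = [(x[:4], y[:4]), (x[4:], y[4:])]
accum_steps = len(micro_batches)

# (1) Gradient accumulation: backward on every micro-batch, each graph freed right away
for xb, yb in micro_batches:
    loss = F.l1_loss(model_a(xb), yb) / accum_steps
    loss.backward()  # adds into model_a.weight.grad / model_a.bias.grad

# (2) Loss accumulation: both graphs stay in memory until the single backward
total_loss = 0
for xb, yb in micro_batches:
    total_loss = total_loss + F.l1_loss(model_b(xb), yb) / accum_steps
total_loss.backward()

same_weight = torch.allclose(model_a.weight.grad, model_b.weight.grad)
same_bias = torch.allclose(model_a.bias.grad, model_b.bias.grad)
print(same_weight, same_bias)  # True True
```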


Oh and as a side question, does the location of where I put loss.backward() affect what I need to normalize the loss by? Or is it always normalized by BATCH_SIZE * 2?

Usually torch's built-in losses have reduction='mean' set as default. This means the loss gets averaged over all batch elements that contributed to calculating the loss. So this will depend on your loss implementation.

However, if you are using gradient accumulation, then yes, you will need to divide your loss by the number of accumulation steps (here loss = F.l1_loss(y_hat, y) / 2), since the gradients of two batches are summed before each update.

To read more about this, I recommend taking a look at this other SO post.
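A quick numerical check of that normalization, again assuming a toy `nn.Linear` model, random data and the default `reduction='mean'`: two micro-batches of 4 whose losses are each divided by the number of accumulation steps give the same gradient as one real batch of 8, while skipping the division doubles it. This only matches exactly when the micro-batches have equal size; a smaller final micro-batch would be weighted differently.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
accum_steps = 2
model_big = nn.Linear(10, 3)            # reference: one real batch of 8
model_acc = copy.deepcopy(model_big)    # 2 x 4, each loss divided by accum_steps
model_nodiv = copy.deepcopy(model_big)  # 2 x 4, no division

x = torch.rand(8, 10)
y = torch.rand(8, 3)

# Full batch of 8 with the default mean reduction
F.l1_loss(model_big(x), y).backward()

# Two micro-batches of 4; gradients accumulate in .grad across the backward calls
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (F.l1_loss(model_acc(xb), yb) / accum_steps).backward()
    F.l1_loss(model_nodiv(xb), yb).backward()

print(torch.allclose(model_big.weight.grad, model_acc.weight.grad))        # True
print(torch.allclose(model_nodiv.weight.grad, 2 * model_big.weight.grad))  # True
```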

Ivan
  • Unfortunately there are two loops in my training code, one for each batch, and then one for each output in a batch. In both the inner and outer loops I'm getting errors, though two different ones. Why can't I call backward in the inner loop? And what's that weird error I'm getting when I call it in the outer loop? – rcplusplus Dec 13 '20 at 18:08
  • You cannot call `.backward()` twice on the same graph (on the same loss function). Maybe what you're looking for is a combination of gradient accumulation (in outer loop) and loss accumulation (in the inner loop). Let me edit my answer. – Ivan Dec 13 '20 at 18:20
  • ah it turns out there was an unrelated bug with `set_grad_enabled` that caused my second code segment to fail. That one works now! But how come I can't call `loss.backward()` in the inner loop like in my first code segment? I'm reconstructing the `loss` object each time, so it should be a different graph, right? If I set `BATCH_SIZE` to 1, it'd basically be the same thing, right? – rcplusplus Dec 13 '20 at 18:26
  • Which snippet exactly, are you referring to? – Ivan Dec 13 '20 at 18:28
  • My second snippet (the one where I accumulate into `batch_loss`) works just fine, but my first snippet (the one where I call `loss.backward()` inside the inner loop) seems to fail, and I'm not sure why it should. It's a different `loss` object each time, and it feels equivalent to using the second snippet with the batch size set to 1. – rcplusplus Dec 13 '20 at 18:32
  • It won't work since you have called `model(inputs_batch)` before entering your inner loop. Therefore, all batch elements share the same activation graph. This means calling `.backward` (even on two seemingly different loss tensors) will result in an error. If you want to do this, you will have to call model for each batch element *separately* inside that inner loop. So you'll be looping over `zip(inputs_batch, labels_batch)` instead of `zip(outputs_batch, labels_batch)`. – Ivan Dec 13 '20 at 18:38
  • Oh I see... Just to clarify, suppose I have `input1` and `input2` in `input_batch` and I get `output1, output2 = model(input_batch)`. Even if I do `criterion(output1, label1).backward()`, it destroys the graph for `criterion(output2, label2)` because they share an underlying graph? – rcplusplus Dec 13 '20 at 18:43
  • Calling `.backward` will free the graph, yes, so the second `.backward` call on the second loss will raise `RuntimeError: Trying to backward through the graph a second time`. You could specify `retain_graph=True` on your first call to keep the graph in memory. But that would defeat the purpose of saving memory in the first place! – Ivan Dec 13 '20 at 18:48
  • That makes so much sense, thanks!! I was tearing my hair out all last night haha – rcplusplus Dec 13 '20 at 18:50
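Putting the thread's conclusion together for the original problem (labels of different sizes): one forward pass per batch, loss accumulation in the inner loop, and a single `backward()` per batch with gradient accumulation in the outer loop. This is a sketch under assumptions: `VariableLabelDataset`, `collate`, the tiny `nn.Conv2d` model, MSE loss and bilinear `F.interpolate` are hypothetical stand-ins for the real dataset, model, `train_criterion` and `interpolate(...)` call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class VariableLabelDataset(Dataset):
    """Hypothetical stand-in: fixed-size inputs, labels whose spatial size varies."""

    def __init__(self, n=16):
        g = torch.Generator().manual_seed(0)
        self.inputs = torch.rand(n, 3, 16, 16, generator=g)
        self.labels = [torch.rand(1, 8 + i % 4, 8 + i % 3, generator=g) for i in range(n)]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]


def collate(batch):
    # Stack the inputs, but keep the labels as a list since they can't be stacked
    inputs, labels = zip(*batch)
    return torch.stack(inputs), list(labels)


BATCH_SIZE = 4
ACCUM_STEPS = 2  # effective batch size = BATCH_SIZE * ACCUM_STEPS = 8

torch.manual_seed(0)
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = DataLoader(VariableLabelDataset(), batch_size=BATCH_SIZE, collate_fn=collate)

optimizer.zero_grad()
steps = 0
for batch_idx, (inputs_batch, labels_batch) in enumerate(dataloader):
    outputs_batch = model(inputs_batch)  # one shared graph for the whole batch

    # Inner loop: only accumulate the per-sample losses, no backward() here
    batch_loss = 0
    for output, label in zip(outputs_batch, labels_batch):
        # interpolate expects a batch dimension, so add one and remove it again
        output_scaled = F.interpolate(output.unsqueeze(0), size=label.shape[-2:],
                                      mode="bilinear", align_corners=False).squeeze(0)
        batch_loss = batch_loss + F.mse_loss(output_scaled, label) / (BATCH_SIZE * ACCUM_STEPS)

    # Outer loop: exactly one backward per shared graph; gradients add up in .grad
    batch_loss.backward()

    if batch_idx % ACCUM_STEPS == ACCUM_STEPS - 1:
        optimizer.step()
        optimizer.zero_grad()
        steps += 1

print(steps)  # 2 (16 samples -> 4 batches -> 2 optimizer steps)
```

If even 4 samples' activations don't fit, the other option mentioned above is to call `model` on each input separately inside the inner loop (e.g. `model(input.unsqueeze(0))`), so that every per-sample loss has its own graph and can be backpropagated right away.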