
I am learning the seq2seq model using https://github.com/keon/seq2seq. I have successfully run the original project. Now I want to train my own translation model.

With my own data, the following code works for the first batch. But for the second batch, out of memory is reported at loss.backward(), even though the second batch is smaller than the first one.

    src, trg = src.cuda().T, trg.cuda().T
    optimizer.zero_grad()
    output = model(src, trg)
    loss = F.nll_loss(output[1:].view(-1, vocab_size),
                           trg[1:].contiguous().view(-1),
                           ignore_index=pad)
    loss.backward(retain_graph=True)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    total_loss += loss.data.item()
    torch.cuda.empty_cache()

With batch size = 16, the above code runs fine for the first batch, and out of memory is reported at loss.backward() for the second batch.

The GPU memory used during the first batch:

    2233 MB    src, trg = src.cuda().T, trg.cuda().T
    2331 MB    output = model(src, trg)
    4772 MB    loss.backward(retain_graph=True)
    6471 MB    optimizer.step()
    5312 MB    torch.cuda.empty_cache()

With batch size = 1, the above code also works for the first batch, and out of memory is still reported at loss.backward() for the second batch.

Any suggestions are appreciated!

Qiang Zhang



To expand slightly on @akshayk07's answer, you should change the loss line to loss.backward(). Retaining the graph requires keeping the intermediate values needed to compute gradients, and is only really useful if you need to backpropagate multiple losses through a single graph. By default, PyTorch automatically frees the graph after a single backward pass to save memory. For more on retain_graph, see What does the parameter retain_graph mean in the Variable's backward() method?. As a rule of thumb, you should only call backward() with retain_graph=True if you plan to make another backward() call through the same graph.
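For illustration, here is a minimal, self-contained sketch (a toy computation, not taken from your model) of the one case where retain_graph=True is actually needed: two losses that share part of a graph and are backpropagated separately.

    import torch

    x = torch.randn(8, 10, requires_grad=True)
    shared = x.tanh()               # shared part of the graph
    loss1 = shared.pow(2).mean()
    loss2 = shared.abs().mean()

    # The first backward must keep the graph, because loss2 still needs it.
    loss1.backward(retain_graph=True)
    # The last backward can (and should) let PyTorch free the graph.
    loss2.backward()

Even in this case, calling (loss1 + loss2).backward() once is usually the cleaner option and avoids retaining the graph at all.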

Most likely, the empty_cache operation does not recognize that the graph-allocated memory is no longer needed, so it does not free this memory after each batch. Retaining these graphs will quickly fill up the GPU memory.
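Putting it together, a corrected training step might look like the sketch below. It reuses the names from your snippet (model, optimizer, vocab_size, pad, and grad_clip are assumed to be defined as in your code); the main changes are dropping retain_graph=True, returning loss.item(), and omitting the per-batch empty_cache call.

    import torch
    import torch.nn.functional as F

    def train_step(model, optimizer, src, trg, vocab_size, pad, grad_clip):
        src, trg = src.cuda().T, trg.cuda().T
        optimizer.zero_grad()
        output = model(src, trg)
        loss = F.nll_loss(output[1:].view(-1, vocab_size),
                          trg[1:].contiguous().view(-1),
                          ignore_index=pad)
        # No retain_graph: PyTorch frees the graph here, so its memory
        # can be reused by the next batch.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        # .item() returns a plain Python float, so accumulating it does
        # not keep any part of the graph alive.
        return loss.item()

With this version, the explicit torch.cuda.empty_cache() call after every batch should also be unnecessary.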

DerekG
  • I used retain_graph=True because I tried loss.backward() twice on the first batch. The second time, an error was reported: Trying to backward through the graph a second time ... That is why I used loss.backward(retain_graph=True). However, even with plain loss.backward(), OOM is still reported for the second batch. It is driving me crazy. – Qiang Zhang May 04 '21 at 01:10
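If the goal is really to backpropagate the same batch twice (as in the comment above), a more memory-friendly pattern is to rerun the forward pass so each backward() gets its own fresh graph, rather than retaining one graph across calls. A minimal sketch, reusing the names from the question:

    for _ in range(2):                      # two updates on the same batch
        optimizer.zero_grad()
        output = model(src, trg)            # fresh forward pass, fresh graph
        loss = F.nll_loss(output[1:].view(-1, vocab_size),
                          trg[1:].contiguous().view(-1),
                          ignore_index=pad)
        loss.backward()                     # graph is freed after this call
        optimizer.step()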