
I have a loss function that requires multiple internal passes:

import torch

def my_loss_func(logits, sigma, labels, num_passes):
    total_loss = 0
    img_batch_size = logits.shape[0]
    logits_shape = list(logits.shape)
    for fpass in range(num_passes):
        # Perturb the logits with Gaussian noise scaled by the predicted sigma.
        noise_array = torch.normal(mean=0.0, std=1.0, size=logits_shape, device=torch.device('cuda:0'))
        stochastic_output = logits + sigma * noise_array
        del noise_array
        # log-sum-exp over the class dimension; exponent_B - stochastic_output
        # is equivalent to -log_softmax(stochastic_output).
        exponent_B = torch.log(torch.sum(torch.exp(stochastic_output), dim=-1, keepdim=True))
        inner_logits = exponent_B - stochastic_output
        soft_inner_logits = labels * inner_logits
        # Each += keeps this pass's graph alive until backward, which is where memory grows.
        total_loss += torch.exp(soft_inner_logits)
        del exponent_B, inner_logits, soft_inner_logits
    mean_loss = total_loss / num_passes
    actual_loss = torch.mean(torch.log(mean_loss))
    return actual_loss
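
For reference, writing P for num_passes, y for labels and ε^(p) for the noise drawn on pass p (and treating the batch/spatial dimensions as a single index i, with k the class index), the quantity the function computes is, if I've transcribed my own code correctly:

$$
z^{(p)} = \text{logits} + \sigma \odot \varepsilon^{(p)}, \qquad
\mathcal{L} = \operatorname*{mean}_{i,k} \, \log\!\left( \frac{1}{P} \sum_{p=1}^{P} \exp\!\Big( y_{i,k}\Big[\log\sum_{j} e^{z^{(p)}_{i,j}} - z^{(p)}_{i,k}\Big] \Big) \right)
$$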

Both logits and sigma are network outputs and therefore have associated gradients. The bottleneck (expectedly) is the line total_loss += torch.exp(soft_inner_logits), since, as far as I know, a new computation graph is appended for each subsequent pass. I've read that calling loss.backward() within the loop could help in similar circumstances (sketched below), but unfortunately I need to log the output after the loop and backprop based on that, so this solution doesn't seem viable here.
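
For completeness, this is roughly the per-pass backward pattern I've seen suggested (a sketch with a made-up function name, not my actual code): each backward call lets that pass's subgraph be freed, but the gradients then correspond to a mean of per-pass losses rather than to the log-of-mean I compute after the loop, which is why I don't think it applies directly.

import torch

def per_pass_backward_sketch(logits, sigma, labels, num_passes):
    # Hypothetical illustration of calling backward() inside the loop so each
    # pass's subgraph can be released immediately instead of accumulating.
    logged_loss = 0.0  # plain Python float, kept only for logging
    for _ in range(num_passes):
        noise = torch.normal(mean=0.0, std=1.0, size=logits.shape, device=logits.device)
        stochastic_output = logits + sigma * noise
        inner_logits = torch.logsumexp(stochastic_output, dim=-1, keepdim=True) - stochastic_output
        # NOTE: this optimises the mean of the per-pass losses, i.e. a
        # decomposable objective, NOT mean(log(mean_p exp(y * inner_logits))).
        pass_loss = torch.mean(labels * inner_logits) / num_passes
        # retain_graph=True keeps the shared upstream (network) graph alive for
        # the next pass; the per-pass subgraph is dropped once its tensors go
        # out of scope, so memory stays roughly constant across passes.
        pass_loss.backward(retain_graph=True)
        logged_loss += pass_loss.item()
    return logged_loss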

To be more specific, I run into memory issues when num_passes exceeds 20. Are there any other ways I could fully optimise memory allocation to allow for a greater number of passes? I'm not at all concerned with readability or ugly solutions; any advice would be of great help.

PedsB
  • Does this answer your question? [Gradient accumulation in an RNN](https://stackoverflow.com/questions/63934070/gradient-accumulation-in-an-rnn) – Szymon Maszke Oct 21 '20 at 18:48
