I have a loss function that requires multiple internal passes, each with freshly sampled noise:
```python
import torch

def my_loss_func(logits, sigma, labels, num_passes):
    total_loss = 0
    img_batch_size = logits.shape[0]
    logits_shape = list(logits.shape)
    for fpass in range(num_passes):
        # Fresh Gaussian noise for every pass, on the same device as logits
        noise_array = torch.normal(mean=0.0, std=1.0, size=logits_shape, device=logits.device)
        stochastic_output = logits + sigma * noise_array
        # del only drops the Python name; autograd keeps noise_array saved for sigma's gradient
        del noise_array
        # log-sum-exp over classes, so inner_logits = -log_softmax(stochastic_output)
        exponent_B = torch.log(torch.sum(torch.exp(stochastic_output), dim=-1, keepdim=True))
        inner_logits = exponent_B - stochastic_output
        soft_inner_logits = labels * inner_logits
        # Every pass adds its own subgraph to the autograd graph of total_loss
        total_loss += torch.exp(soft_inner_logits)
        del exponent_B, inner_logits, soft_inner_logits
    mean_loss = total_loss / num_passes
    actual_loss = torch.mean(torch.log(mean_loss))
    return actual_loss
```
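Written out, assuming `logits` has shape (B, C), the function computes the following. The notation is introduced here for illustration only: z = `logits`, σ = `sigma` (broadcast over classes if needed), ε = `noise_array`, y = `labels`, s = `stochastic_output`, u = `soft_inner_logits`, M = `mean_loss`, K = `num_passes`, and θ stands for the network parameters.

$$
\begin{aligned}
s^{(k)} &= z + \sigma \odot \varepsilon^{(k)}, \quad \varepsilon^{(k)} \sim \mathcal{N}(0, I), \quad k = 1, \dots, K \\
u^{(k)}_{bc} &= y_{bc}\Big(\log \sum_{c'} \exp s^{(k)}_{bc'} - s^{(k)}_{bc}\Big) \\
M_{bc} &= \frac{1}{K}\sum_{k=1}^{K} \exp u^{(k)}_{bc}, \qquad \mathcal{L} = \frac{1}{BC}\sum_{b,c} \log M_{bc} \\
\frac{\partial \mathcal{L}}{\partial \theta} &= \frac{1}{BC}\sum_{b,c} \frac{1}{K\, M_{bc}} \sum_{k=1}^{K} \exp\big(u^{(k)}_{bc}\big)\, \frac{\partial u^{(k)}_{bc}}{\partial \theta}
\end{aligned}
$$

The last line shows that each pass contributes to the gradient with a weight 1/(K M_{bc}) that depends on the other passes only through M.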
Both `logits` and `sigma` are network outputs and therefore require gradients. The bottleneck is (as expected) the line `total_loss += torch.exp(soft_inner_logits)`, since, as far as I know, every pass appends a new subgraph to the computation graph, and the tensors saved for backward in each pass stay alive until `backward()` is called. I've read that calling `loss.backward()` inside the loop can help in similar situations, but I need to take the log of the averaged output after the loop and backprop through that, so this solution doesn't seem viable here.
To be more specific, I run into memory issues once `num_passes` exceeds 20. Are there any other ways I could optimise memory usage to allow a larger number of passes? I'm not at all concerned about readability, so ugly solutions are fine; any advice will be of great help.