
I'm a student and a beginner in both Python and PyTorch. I have a very basic neural network for which I am encountering the mentioned RuntimeError. The code to reproduce the error is this:

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Ensure Reproducibility
torch.manual_seed(0)

# Data Generation
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
# Shuffles the indices
idx = np.arange(100)
np.random.shuffle(idx)

# Uses first 70 random indices for train
train_idx = idx[:70]
# Uses the remaining indices for validation
val_idx = idx[70:]

# Generates train and validation sets
x_train, y_train = x[train_idx], y[train_idx]
x_val, y_val = x[val_idx], y[val_idx]

class OurFirstNeuralNetwork(nn.Module):
    def __init__(self):
        super(OurFirstNeuralNetwork, self).__init__()
        # Here we "define" our Neural Network Architecture
        self.fc1 = nn.Linear(1, 5)
        self.non_linearity_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(5,1)
        #self.non_linearity_fc2 = nn.ReLU()

    def forward(self, x):
        # The forward pass
        # Here we define how activations "flow" between neurons. We've already discussed the "Sum" and "Transformation" steps of the forward pass.
        sum_fc1 = self.fc1(x)
        transformation_fc1 = self.non_linearity_fc1(sum_fc1)
        sum_fc2 = self.fc2(transformation_fc1)
        #transformation_fc2 = self.non_linearity_fc2(sum_fc2)
        # sum_fc2 is also the output of our model, which symbolises the end of our forward pass.
        return sum_fc2

# Instantiate the model and train

model = OurFirstNeuralNetwork()
print(model)
print(model.state_dict())
n_epochs = 1000
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters())

for epoch in range(n_epochs):


    model.train()
    optimizer.zero_grad()
    prediction = model(x_train)
    loss = loss_fn(y_train, prediction)
    print(epoch, loss)
    loss.backward(retain_graph=True)    
    optimizer.step()


print(model.state_dict())

Everything is basic and standard, and this works fine.

However, when I take out the "retain_graph=True" argument, it throws the RuntimeError. From reading various forums, I understand that this has to do with the graph getting thrown away after the first iteration, but I have seen many tutorials and blogs where loss.backward() is the way to go, especially since it conserves memory. I am not able to conceptually grasp why the same does not work for me.
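For what it's worth, here is a minimal standalone snippet (a toy example I put together, not my actual model) that, as far as I understand, hits the same error, because part of the graph is built only once outside the loop and gets freed on the first backward call:

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2                   # this part of the graph is built only once
for _ in range(2):
    loss = (b ** 2).sum()
    loss.backward()         # second call fails: the buffers of the a -> b node were freed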

Any help is appreciated, and my apologies if the way in which I have asked my question is not in the expected format. I am open to feedback and am happy to include more details or rephrase the question so that it is easier for everyone. Thank you in advance!

1 Answer


You need to add optimizer.zero_grad() after optimizer.step() to zero out the gradients.

Why do you need to do this?

When you call loss.backward(), torch computes the gradients for the parameters and accumulates them into each parameter's .grad attribute. When you call optimizer.step(), the parameters are updated using .grad, i.e. roughly `parameter = parameter - lr * parameter.grad`.
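For a plain SGD optimizer this update is roughly equivalent to the sketch below, using the model from your question (lr is just an illustrative value; Adam, which you use, applies a more involved update, but it still reads from .grad):

lr = 0.01
with torch.no_grad():
    for param in model.parameters():
        param -= lr * param.grad    # what optimizer.step() does for vanilla SGD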

Since you do not clear the gradients and call backward a second time, it will compute dl/d(updated param), which would require backpropagating through parameter.grad of the first pass. The computation graph of those gradients is not stored when doing backward, and hence you have to pass retain_graph=True to get rid of the error. However, we don't want to do that for updating the params. Rather, we want to clear the gradients and restart with a new computation graph; therefore, you need to zero the gradients with a .zero_grad() call.
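As a small standalone illustration (a toy tensor w, not your model): gradients accumulate across backward calls until they are cleared:

import torch

w = torch.ones(1, requires_grad=True)
(2 * w).sum().backward()
print(w.grad)       # tensor([2.])
(3 * w).sum().backward()
print(w.grad)       # tensor([5.]) -- the new gradient was added to the old one
w.grad.zero_()      # this is what optimizer.zero_grad() does for every parameter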

Also see Why do we need to call zero_grad() in PyTorch?

Umang Gupta
  • Thanks for your explanation Umang, it was super and makes perfect sense, and I expected this to work. However, it still throws the exact same error without the retain_graph=True argument. With retain_graph=True, it works fine. – Bhargav Desai Apr 05 '20 at 18:50
  • Is there some particular reason to have `requires_grad=True` for the input `x`? That is actually what is causing the problem here, as you are trying to backpropagate multiple times through it without clearing gradients. – Umang Gupta Apr 05 '20 at 23:05
  • I set it to true because otherwise the 'grad_fn' attribute will be None, and I thought that is needed so that autograd can do its thing. Please correct me if I'm wrong. – Bhargav Desai Apr 06 '20 at 15:27
  • So `.backward` will work without setting that to true for the input (because you want gradients for the parameters, not the input). If you want to make it work with grad set to true for the input, please also clear the gradients of the input. – Umang Gupta Apr 06 '20 at 15:44
  • Got it Umang! Thanks a lot for your prompt responses and clear explanations! – Bhargav Desai Apr 06 '20 at 19:37
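Putting the comment thread together, a minimal sketch of the change that lets the loop run without retain_graph=True is to generate the data without requires_grad=True on the input; the nn.Sequential below just stands in for the OurFirstNeuralNetwork class from the question, and the train/validation split is omitted for brevity:

import torch
from torch import nn, optim

torch.manual_seed(0)

# The input does not need requires_grad=True; autograd only needs gradients for the model's parameters.
x = torch.randn((100, 1))
y = 1 + 2 * x + 0.3 * torch.randn(100, 1)

model = nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1))
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters())

for epoch in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()         # no retain_graph needed: each iteration builds a fresh graph
    optimizer.step()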