
I have a special use case in which I have to separate inference and back-propagation: I have to run inference on all images, slice the outputs into batches, and then back-propagate batch by batch. I don't need to update my network's weights.

I modified snippets of cifar10_tutorial into the following to simulate my problem: j is a variable representing an index returned by my own logic, and I want the gradient of some variables.

for epoch in range(2):  # loop over the dataset multiple times

    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs.requires_grad = True

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)

        for j in range(4): # j is given by external logic in my own case

            loss = criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))

            loss.backward()

            print(inputs.grad.data[j, :]) # what I really want

I got the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

My questions are:

  1. According to my understanding, the problem arises because the first backward pass back-propagates through the whole graph of outputs, so the buffers needed for outputs[1, :].unsqueeze(0) are released and the second backward pass fails. Am I right?

  2. In my case, if I set retain_graph=True, will the code run slower and slower, as described in this post?

  3. Is there a better way to achieve my goal?


1 Answer

  1. Yes, you are correct. When you have already back-propagated through outputs the first time (first iteration), the buffers are freed, and the following backward call (the next iteration of your loop) fails because the data necessary for this computation has already been removed.

  2. Yes, the graph grows bigger and bigger, so it could be slower depending on GPU (or CPU) usage and your network. I used this once and it was much slower; however, this depends a lot on your network architecture. In any case, you will need more memory with retain_graph=True than without.
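
    For completeness, here is a minimal sketch of how the loop from the question would run with retain_graph=True (reusing net, criterion, inputs and labels from the question; note that the whole graph stays in memory until the loop finishes):

    outputs = net(inputs)
    for j in range(4):
        loss = criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
        # keep the graph alive so the next iteration can backward through it again
        loss.backward(retain_graph=True)
        print(inputs.grad[j, :])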

  3. Depending on the shapes of your outputs and labels, you should be able to calculate the loss for all your outputs and labels at once:

    criterion(outputs, labels)
    

    You then skip the j-loop entirely, which also makes your code faster. You may need to reshape (resp. view) your data, but this should work fine.
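
    With the question's variables, that could look like this (a sketch, assuming outputs and labels already have matching shapes):

    inputs.requires_grad = True
    outputs = net(inputs)
    loss = criterion(outputs, labels)  # one loss over the whole batch
    loss.backward()                    # a single backward pass
    print(inputs.grad)                 # gradients w.r.t. all inputs at once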

    If for some reason you cannot do that, you can manually sum up the loss in a tensor and call backward after the loop. This works fine too, but it is slower than the solution above.

    Your code would then look like this:

    # init loss tensor
    loss = torch.tensor(0.0) # move to GPU if you're using one
    
    for j in range(4):
        # summing up your loss for every j
        loss += criterion(outputs[j, :].unsqueeze(0), labels[j].unsqueeze(0))
        # ...
    # calling backward on the summed loss - getting gradients
    loss.backward()
    # as you call backward now only once on the outputs
    # you shouldn't get any error and you don't have to use retain_graph=True
    
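    Note that inputs.grad then holds the gradient of the summed loss; for a standard feed-forward network without cross-sample operations (e.g. batch norm in training mode), the row inputs.grad[j, :] still only receives contributions from the j-th sample's loss term.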

Edit:

Accumulating the losses and calling backward afterwards is completely equivalent. Here is a small example with and without accumulating the losses:

First, create some data:

# w in this case will represent a very simple model
# I leave out the CE and just use w to map the output to a scalar value
w = torch.nn.Linear(4, 1)
data = [torch.rand(1, 4) for j in range(4)]

data looks like:

[tensor([[0.4593, 0.3410, 0.1009, 0.9787]]),
 tensor([[0.1128, 0.0678, 0.9341, 0.3584]]),
 tensor([[0.7076, 0.9282, 0.0573, 0.6657]]),
 tensor([[0.0960, 0.1055, 0.6877, 0.0406]])]

Let's first do it the way you're doing it, calling backward for every iteration j separately:

# code for directly applying backward
# zero the weights layer w
w.zero_grad()
for j, inp in enumerate(data):
    # activate grad flag
    inp.requires_grad = True
    # remove / zero previous gradients for inputs
    inp.grad = None
    # apply model (only consists of one layer in our case)
    loss = w(inp)
    # calling backward on every output separately
    loss.backward()
    # print out grad
    print('Input:', inp)
    print('Grad:', inp.grad)
    print()
print('w.weight.grad:', w.weight.grad)

Here is the print-out with every input and its respective gradient, along with the gradients for the model (resp. the layer w in our simplified case):

Input: tensor([[0.4593, 0.3410, 0.1009, 0.9787]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.1128, 0.0678, 0.9341, 0.3584]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.7076, 0.9282, 0.0573, 0.6657]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.0960, 0.1055, 0.6877, 0.0406]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

w.weight.grad: tensor([[1.3757, 1.4424, 1.7801, 2.0434]])

Now, instead of calling backward once for every iteration j, we accumulate the values, call backward on the sum, and compare the results:

# init tensor for accumulation
loss = torch.tensor(0.0)
# zero layer gradients
w.zero_grad()
for j, inp in enumerate(data):
    # activate grad flag
    inp.requires_grad = True
    # remove / zero previous gradients for inputs
    inp.grad = None
    # apply model (only consists of one layer in our case)
    # accumulating values instead of calling backward
    loss += w(inp).squeeze()
# calling backward on the sum
loss.backward()

# printing out gradients 
for j, inp in enumerate(data):
    print('Input:', inp)
    print('Grad:', inp.grad)
    print()
print('w.grad:', w.weight.grad)

Let's take a look at the results:

Input: tensor([[0.4593, 0.3410, 0.1009, 0.9787]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.1128, 0.0678, 0.9341, 0.3584]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.7076, 0.9282, 0.0573, 0.6657]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

Input: tensor([[0.0960, 0.1055, 0.6877, 0.0406]], requires_grad=True)
Grad: tensor([[-0.0999,  0.2665, -0.1506,  0.4214]])

w.grad: tensor([[1.3757, 1.4424, 1.7801, 2.0434]])

Comparing the results, we can see that both are the same.
This is a very simple example, but it nevertheless shows that calling backward() on every single tensor and summing up the tensors and then calling backward() once are equivalent in terms of the resulting gradients, for both inputs and weights. This follows from the linearity of the gradient: the gradient of a sum of losses is the sum of the individual gradients.
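
If you want to verify this programmatically rather than by eye, here is a small self-contained check (a sketch; the seed and shapes are arbitrary):

import torch

torch.manual_seed(0)
w = torch.nn.Linear(4, 1)
data = [torch.rand(1, 4, requires_grad=True) for _ in range(4)]

# run 1: call backward once per sample
w.zero_grad()
grads_separate = []
for inp in data:
    inp.grad = None
    w(inp).squeeze().backward()
    grads_separate.append(inp.grad.clone())
w_grad_separate = w.weight.grad.clone()

# run 2: accumulate the losses, then call backward once on the sum
w.zero_grad()
loss = torch.tensor(0.0)
for inp in data:
    inp.grad = None
    loss = loss + w(inp).squeeze()
loss.backward()
grads_summed = [inp.grad.clone() for inp in data]

# gradients of inputs and weights agree between the two runs
assert all(torch.allclose(a, b) for a, b in zip(grads_separate, grads_summed))
assert torch.allclose(w_grad_separate, w.weight.grad)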

When you use CE for all j's at once, as described in 3., you can use the flag reduction='sum' to achieve the same behaviour as summing up the CE values; the default is 'mean', which leads to slightly different (scaled) results.
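
For example, a minimal sketch using the standard reduction argument of torch.nn.CrossEntropyLoss:

criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum')
loss = criterion_sum(outputs, labels)  # equal to summing the per-sample CE values
loss.backward()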

  • Thank you for your quick and excellent reply. But in my case, I don't need the loss at all; I only care about the gradients. So is `retain_graph=True` my only solution? – Tengerye Dec 19 '18 at 11:32
  • Did you try what I've written at ***3.***? This should still apply for you, regardless of what you do later after calling `backward` on the `loss`-tensor. – MBT Dec 19 '18 at 13:33
  • I sincerely apologize for the confusion; would you please have a look at the example I have edited in now? – Tengerye Dec 20 '18 at 01:10
  • @Tengerye I wonder why you want to create gradients w.r.t. your input data? But I still don't see why you can't call `backward` once for every iteration *i* as I described in ***3.***. Did you try? If so, why didn't it work? – MBT Dec 20 '18 at 15:44
  • I clarified my question again. Sorry for the late reply. I am studying attacks on deep learning models, so I care about the gradients w.r.t. the input data instead of the weights. Your kind reply does remind me of back-propagating only once, and I am still working on it. Sincerely thank you again. @blue-phoenox – Tengerye Dec 22 '18 at 08:54
  • I am not sure whether the loss is still correct when I try to backward on the loss of a specific `j`. – Tengerye Dec 22 '18 at 08:57
  • @Tengerye I made an edit, there you can see that it is equivalent. – MBT Dec 22 '18 at 16:06