
If I run the code:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

PyTorch spits the error "Trying to backward through the graph a second time" at me. My understanding is that calling the loss calculation line again doesn't actually change the computational graph, which is why I get this error. However, when I run the code:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

it works fine (without error), and I don't understand why: in either case, I haven't made any change to the computational graph, have I?

QCD_IS_GOOD

1 Answer


This is a good question. In my opinion, understanding it is particularly important in order to fully grasp this feature of PyTorch, which becomes paramount when dealing with complex setups involving multiple or partial backward passes.

In both examples your computational graph is:

y ---------------------------->|
b ----------->|                |
w ------->|                    |
x --> x @ w + b = z --> BCE(z, y) = loss

However, the "computational graph" as we call it is just a representation of the dependencies that exist in the computation of that result. The way this result is tied to the tensors that lead to the final computation, i.e. the intermediate results of the graph. When you compute loss, a link remains between loss and all other tensors, this is needed in order to compute the backward pass.
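
One way to see this link directly is to inspect the grad_fn attributes; here is a small sketch of my own, reusing the question's tensors (the exact class names printed may vary slightly between PyTorch versions):

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

# every non-leaf tensor keeps a reference to the function that produced it
print(z.grad_fn)     # e.g. <AddBackward0 object at 0x...>
print(loss.grad_fn)  # e.g. <BinaryCrossEntropyWithLogitsBackward0 object at 0x...>

# each grad_fn exposes the edges pointing further up the graph
print(loss.grad_fn.next_functions)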

First scenario

In your first example, computing loss by itself creates a "computational graph". Notice the grad_fn attribute appearing on your loss variable: this is the callback function used to navigate back up the graph. In your case, F.binary_cross_entropy_with_logits will output a grad_fn=<BinaryCrossEntropyWithLogitsBackward>. You then successfully compute the backward pass by calling backward(): doing so backpropagates up the graph using the grad_fn callbacks and accumulates the results into the leaf tensors' grad attribute. Then you define loss again using the same z, the one that is tied to the previous graph. You're essentially going from the computational graph above to the following one:

y ---------------------------->|
b ----------->|                |
w ------->|                    |
x --> x @ w + b = z --> BCE(z, y) = loss
                   \--> BCE(z, y) = loss # 2nd definition of loss

The second definition of loss overwrites the previous value of loss, yes. However, it won't affect the first portion of the graph, which still exists: as explained above, z is still tied to the initial tensors x, w, and b.
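
To make both points concrete, here is a sketch (my own addition, again reusing the question's tensors): the first backward() accumulates gradients into the leaf tensors' grad attribute, and rebinding the name loss afterwards leaves the upstream part of the graph, reachable through z, untouched:

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
print(w.grad, b.grad)  # None None -- nothing has been backpropagated yet

loss.backward()
print(w.grad.shape, b.grad.shape)  # torch.Size([5, 3]) torch.Size([3])

# rebinding 'loss' creates a second BCE node hanging off the same z;
# the first portion of the graph is still reachable through z.grad_fn
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
print(z.grad_fn)  # e.g. <AddBackward0 object at 0x...> -- still attached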

By default, the intermediate buffers (the activations saved for gradient computation) are freed during a backward pass. This means you won't be able to perform a second pass over the same graph. To sum up your first example: the second loss.backward() goes through the new loss's grad_fn, then reaches the initial z, whose saved buffers have already been freed. This results in the error you've encountered:

Trying to backward through the graph a second time
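
As a side note (not part of the question, but the usual remedy for this error): if you really do need a second backward pass over the same graph, you can ask PyTorch to keep the saved buffers alive by passing retain_graph=True to the first call. A minimal sketch:

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward(retain_graph=True)  # keep the intermediate buffers alive
loss.backward()                   # second pass on the same graph now succeeds

Bear in mind that retain_graph=True keeps those buffers in memory, so only use it when a second pass is genuinely needed.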

Second scenario

In the second example, you redefine the whole network, recomputing z from the leaf tensors x, w, and b, and consequently loss from the new intermediate z and the leaf tensor y.

Conceptually, the state of the computation graphs is:

y ---------------------------->|
b ----------->|                |
w ------->|                    |
x --> x @ w + b = z --> BCE(z, y) = loss
  \-> x @ w + b = z --> BCE(z, y) = loss # 2nd definition of loss

This means that calling loss.backward() a first time performs a backward pass on the initial graph. Then, after having redefined both z and loss, you end up creating a new graph altogether: the 2nd branch of the illustration above. The 2nd backward pass works because it runs on this new graph, whose buffers haven't been freed yet.
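
One detail worth adding (my note, not something the question asked about): even though the second backward runs on a brand-new graph, both passes write into the same grad attributes, which accumulate rather than overwrite. A sketch:

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()
first_grad = w.grad.clone()  # gradient produced by the first graph

z = torch.matmul(x, w) + b   # new graph, same values
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

# same inputs, so the second pass added an identical gradient on top
print(torch.allclose(w.grad, 2 * first_grad))  # True

In a training loop you would typically call optimizer.zero_grad() (or w.grad.zero_()) between passes to avoid this accumulation.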

Ivan