This is a good question. In my opinion, understanding this point is particularly important in order to fully grasp this feature of PyTorch, which is paramount when dealing with complex setups, whether they involve multiple backward passes or partial backward passes.
In both examples your computational graph is:
```
y ---------------------------->|
b ----------->|                |
w ------->|   |                |
x --> x @ w + b = z --> BCE(z, y) = loss
```
However, the "computational graph" as we call it is just a representation of the dependencies that exist in the computation of that result. The way this result is tied to the tensors that lead to the final computation, i.e. the intermediate results of the graph. When you compute loss
, a link remains between loss
and all other tensors, this is needed in order to compute the backward pass.
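As a rough, self-contained sketch (the shapes and values below are made up for illustration, they are not taken from your code), you can see those links by inspecting the `grad_fn` attributes after the forward pass:

```python
import torch
import torch.nn.functional as F

# Hypothetical tensors, only to make the example concrete
x = torch.randn(8, 3)                       # input, leaf tensor without gradients
w = torch.randn(3, 1, requires_grad=True)   # parameter, leaf tensor
b = torch.randn(1, requires_grad=True)      # parameter, leaf tensor
y = torch.randint(0, 2, (8, 1)).float()     # targets

z = x @ w + b                                     # intermediate result
loss = F.binary_cross_entropy_with_logits(z, y)   # final scalar

print(z.grad_fn)     # <AddBackward0 ...>: z is tied to x, w and b
print(loss.grad_fn)  # <BinaryCrossEntropyWithLogitsBackward ...> (exact name varies with the PyTorch version)
```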
First scenario
In your first example you compute `loss`, which by itself creates a "computational graph". Notice the `grad_fn` attribute appearing on your `loss` variable: this is the callback function used to navigate back up the graph. In your case `F.binary_cross_entropy_with_logits` will output a tensor with `grad_fn=<BinaryCrossEntropyWithLogitsBackward>`. That being said, you successfully compute the backward pass by calling `backward()`; doing so backpropagates up the graph through the `grad_fn` functions and updates the parameters' `grad` attributes. Then you define `loss` again using the same `z`, the one that is tied to the previous graph. You're essentially going from the previous computational graph above to the following one:
```
y ---------------------------->|
b ----------->|                |
w ------->|   |                |
x --> x @ w + b = z --> BCE(z, y) = loss
                   \--> BCE(z, y) = loss   # 2nd definition of loss
```
The second definition of `loss` overwrites the previous value of `loss`, yes. However, it won't affect the first portion of the graph, which still exists: as I explained, `z` is still tied to the initial tensors `x`, `w`, and `b`.
By default, during a backward pass, the saved activations are freed. This means you won't be able to perform a second pass over the same graph. To sum up your first example: the second `loss.backward()` will go through the new `loss`'s `grad_fn`, then reach the initial `z`, whose activations have already been freed. This results in the error you've encountered:

> Trying to backward through the graph a second time
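Here is a minimal sketch of that first scenario (again with made-up shapes) which reproduces the error:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 3)
w = torch.randn(3, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
y = torch.randint(0, 2, (8, 1)).float()

z = x @ w + b
loss = F.binary_cross_entropy_with_logits(z, y)
loss.backward()       # 1st pass: works, and frees the activations saved in z's part of the graph

loss = F.binary_cross_entropy_with_logits(z, y)   # 2nd definition, built on the *same* z
try:
    loss.backward()   # 2nd pass reaches the old part of the graph, whose buffers are gone
except RuntimeError as e:
    print(e)          # Trying to backward through the graph a second time ...
```

If you really do need two backward passes through the same graph, passing `retain_graph=True` to the first `backward()` keeps the saved activations alive, at the cost of extra memory.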
Second scenario
In the second example, you redefine the whole network by recomputing `z` from the leaf tensor `x`, and consequently `loss` from the intermediate output `z` and the leaf tensor `y`.
Conceptually, the state of the computational graphs is:
```
y ---------------------------->|
b ----------->|                |
w ------->|   |                |
x --> x @ w + b = z --> BCE(z, y) = loss
 \--> x @ w + b = z --> BCE(z, y) = loss   # 2nd definition of loss
```
This means that calling `loss.backward()` the first time does a backward pass on the initial graph. Then, after having redefined both `z` and `loss`, you end up creating a new graph altogether: the 2nd branch of the illustration above. The 2nd backward pass ends up working because you're not backpropagating through the same graph.
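As a sketch with the same made-up shapes, the second scenario looks like this; each `backward()` runs on its own graph, and the gradients of the two passes accumulate in `w.grad` and `b.grad`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 3)
w = torch.randn(3, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
y = torch.randint(0, 2, (8, 1)).float()

z = x @ w + b
loss = F.binary_cross_entropy_with_logits(z, y)
loss.backward()       # 1st pass, on the 1st graph

z = x @ w + b                                     # recomputed from the leaf tensors: fresh activations
loss = F.binary_cross_entropy_with_logits(z, y)   # a new graph altogether
loss.backward()       # 2nd pass, on the 2nd graph: no error

print(w.grad.shape, b.grad.shape)   # gradients from both passes have been summed into .grad
```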