
I want to minimize a loss function of a symmetric matrix in which some entries are fixed. To do this, I defined the tensor A_nan and placed torch.nn.Parameter objects at the entries to be estimated.

However, when I try to run the code I get the following exception:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I found this question, which seemed to describe the same problem, but the solution proposed there does not apply to my case as far as I understand, or at least I would not know how to apply it.

Here is a self-contained example of what I am trying to do:

import torch

# Symmetric matrix; NaN marks the entries to be estimated
A_nan = torch.tensor([[1.0, 2.0, torch.nan], [2.0, torch.nan, 5.0], [torch.nan, 5.0, 6.0]])
nan_idxs = torch.where(torch.isnan(torch.triu(A_nan)))
A_est = torch.clone(A_nan)
weights = torch.nn.ParameterList([])
for i, j in zip(*nan_idxs):
    # One parameter per missing entry, shared between the two symmetric positions
    w = torch.nn.Parameter(torch.distributions.Normal(3, 0.5).sample())
    A_est[i, j] = w
    A_est[j, i] = w
    weights.append(w)

optimizer = torch.optim.Adam(weights, lr=0.01)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.sum(A_est ** 2)
    loss.backward()
    optimizer.step()

1 Answer


The computation graph gets destroyed after calling loss.backward(). Note that A_est is not a leaf, so the graph connecting the parameters to A_est is not recreated on the second pass of the loop, and the second backward() fails.
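
With the setup from the question you can verify this (just a quick check, not part of the fix):

print(A_est.is_leaf)   # False
print(A_est.grad_fn)   # not None: A_est carries graph history from the parameter assignments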

One possible solution would be:

for _ in range(2):
    optimizer.zero_grad()
    # Drop the old graph history, then rebuild A_est from the current parameters
    A_est = A_est.detach()
    for idx, (i, j) in enumerate(zip(*nan_idxs)):
        w = weights[idx]
        A_est[i, j] = w
        A_est[j, i] = w
    loss = torch.sum(A_est ** 2)
    loss.backward()
    optimizer.step()
    print(loss)
    # tensor(110.3664, grad_fn=<SumBackward0>)
    # tensor(109.0410, grad_fn=<SumBackward0>)
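
An equivalent variation (just a sketch, reusing A_nan, nan_idxs, weights and optimizer from the question) is to rebuild A_est from a fresh clone of A_nan on every iteration, so the graph from the parameters to the loss is created anew each time:

for _ in range(10):
    optimizer.zero_grad()
    A_est = torch.clone(A_nan)          # fresh tensor, no stale graph attached
    for idx, (i, j) in enumerate(zip(*nan_idxs)):
        A_est[i, j] = weights[idx]
        A_est[j, i] = weights[idx]
    loss = torch.sum(A_est ** 2)        # new graph from the parameters to the loss
    loss.backward()
    optimizer.step()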