To be clear, I am not:
- Asking how to prevent gradients from being propagated to certain tensors (in this case you can just set requires_grad = False for that tensor).
- Asking how to prevent gradients from being propagated from an entire tensor (in that case you can just call tensor.detach(); see this question). Both whole-tensor cases are sketched just below.
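For concreteness, here is a minimal sketch of those two whole-tensor cases (tensor names are just illustrative):

import torch

# Case 1: keep gradients from ever being computed for a tensor
frozen = torch.randn(3, 3, requires_grad=False)

# Case 2: cut the graph at a whole tensor with detach()
x = torch.randn(3, 3, requires_grad=True)
y = (x * 2).detach()        # y is severed from the graph
loss = (y + x).sum()
loss.backward()
print(x.grad)               # all ones: only the "+ x" term contributes, nothing flows through y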
I'm wondering how to forgo gradient computation for some elements of a loss tensor that give a NaN gradient every time -- essentially, how to call .detach() for individual elements of a tensor. The way to do this in TensorFlow is tf.stop_gradient; see this question.
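To illustrate the behaviour I'm after on a toy tensor (just a sketch of the desired semantics -- whether and where such a trick should be applied in my actual computation is exactly what I'm unsure about):

import torch

x = torch.randn(4, requires_grad=True)
keep = torch.tensor([True, False, True, False])

# Recombine the tensor with a detached copy so that only some elements carry gradients
y = torch.where(keep, x, x.detach())
y.sum().backward()
print(x.grad)   # tensor([1., 0., 1., 0.]) -- the masked-out elements behave like constants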
Some context: my neural network computes a distance matrix of its predicted coordinates, as follows. The entries of the distance matrix D are given by d_ij = || coordinates_i - coordinates_j ||. I want to backpropagate through the distance matrix creation step. However, the norm includes a square root, which is not differentiable at 0, and the diagonal of the distance matrix is 0 by construction. Thus I get NaN gradients for the diagonal of the distance matrix, and I would like to mask out the gradients on the diagonal.
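The failure mode is easy to reproduce in isolation (a toy snippet, unrelated to the network): the derivative of sqrt is 1 / (2 * sqrt(x)), which is infinite at 0, and as soon as the chain rule multiplies that by the zero upstream gradient coming from the diagonal (the target diagonal is also 0, so the MSE gradient there is 0), the result is NaN:

import torch

x = torch.zeros(1, requires_grad=True)

torch.sqrt(x).backward()
print(x.grad)        # tensor([inf]) -- derivative of sqrt blows up at 0

x.grad = None
(0.0 * torch.sqrt(x)).backward()
print(x.grad)        # tensor([nan]) -- zero upstream gradient times inf gives NaN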
Minimal working example:
import torch

def compute_distance_matrix(coordinates):
    # coordinates: L x 3
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()
gives
1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...