
To be clear, I am not:

  • Asking how to prevent gradients from being propagated to certain tensors (in this case you can just set requires_grad = False for that tensor).
  • Asking how to prevent gradients from being propagated from an entire tensor (in that case you can just call tensor.detach(), see this question).

I'm wondering how to skip gradient computation for the elements of a loss tensor that give a NaN gradient every time; essentially, how to call .detach() on individual elements of a tensor. In TensorFlow this can be done with tf.stop_gradient; see this question.
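For comparison, one rough PyTorch counterpart of an element-wise stop-gradient is `torch.where(mask, x, x.detach())`: the forward values are unchanged, but only positions where `mask` is True send gradient back to `x`. This is a sketch, not a dedicated PyTorch API; the helper name `detach_elements` is hypothetical.

```python
import torch


def detach_elements(x: torch.Tensor, keep_grad: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same values as x, but gradient flows only where keep_grad is True."""
    # torch.where selects instead of multiplying, so an inf/NaN gradient arriving at a
    # masked-out position is replaced by 0 rather than propagated back to x.
    return torch.where(keep_grad, x, x.detach())


x = torch.tensor([0.0, 4.0], requires_grad=True)
keep = torch.tensor([False, True])
# The mask acts on the input of sqrt, so the infinite gradient at 0 is dropped.
y = torch.sqrt(detach_elements(x, keep))
y.sum().backward()
print(x.grad)  # tensor([0.0000, 0.2500])
```

Applied to the output of `torch.sqrt` instead, this would not help: the square root's own backward would still compute 0 / 0 on the zero entries.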

Some context: my neural network predicts coordinates, from which I compute a distance matrix D whose entries are d_ij = || coordinates_i - coordinates_j ||. I want to backpropagate through this distance-matrix computation. However, the norm involves a square root, which is not differentiable at 0 (its derivative 1 / (2 sqrt(x)) is unbounded there), and the diagonal of the distance matrix is 0 by construction. As a result, I get NaN gradients for the diagonal entries. I would like to mask out the gradients on the diagonal of the distance matrix.
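A small sketch of where the NaN comes from, assuming (as in current PyTorch releases) that the backward of `torch.sqrt` divides the incoming gradient by `2 * sqrt(x)`. On the diagonal, the predicted and true distances are both exactly 0, so the MSE loss sends back a gradient of 0, and the square root turns that into 0 / 0.

```python
import torch

# One zero entry (like the distance-matrix diagonal) and one positive entry.
s = torch.tensor([0.0, 4.0], requires_grad=True)
d = torch.sqrt(s)
# An upstream gradient of 0 on the first entry mimics the MSE loss on the diagonal,
# where predicted and true distances are both exactly 0.
d.backward(torch.tensor([0.0, 1.0]))
# sqrt's backward computes grad / (2 * sqrt(s)): 0 / 0 = nan and 1 / (2 * 2) = 0.25.
print(s.grad)
```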

Minimal working example:

import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()  # clear gradients accumulated by the previous iteration
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()

gives

1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...
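To confirm that the NaN originates in the square root rather than elsewhere, one option is PyTorch's anomaly mode (`torch.autograd.detect_anomaly`), which raises an error naming the backward function whose output first contains NaN. A sketch, reusing the question's Gram-matrix formula in a condensed, broadcasting form; the seed and sizes are arbitrary choices for illustration.

```python
import torch


def compute_distance_matrix(coordinates):
    # Same Gram-matrix formula as the question, written with broadcasting.
    gram = coordinates @ coordinates.T
    sq_norms = torch.diagonal(gram)
    return torch.sqrt(sq_norms[:, None] + sq_norms[None, :] - 2 * gram)


torch.manual_seed(0)
pred_coordinates = torch.randn(10, 3, requires_grad=True)
true_coordinates = torch.randn(10, 3)

error_message = None
# Anomaly mode checks each backward function's output for NaN (for debugging only; it is slow).
with torch.autograd.detect_anomaly():
    loss = torch.nn.MSELoss()(
        compute_distance_matrix(pred_coordinates),
        compute_distance_matrix(true_coordinates),
    )
    try:
        loss.backward()
    except RuntimeError as err:
        error_message = str(err)

# Expected to name the square root's backward node ("SqrtBackward0" in recent releases).
print(error_message)
```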
Jacob Stern

1 Answer


I initialized a new zero matrix and used a mask to copy in only the entries whose gradients are well defined (here, the off-diagonal entries), applying the operation that is not differentiable everywhere (the square root) only to those entries. Gradients then flow back only through the entries where the mask is True.

import torch

def compute_distance_matrix(coordinates):
    # Squared distances via the identity |x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i . x_j
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    squared_distance_matrix = diag_1 + diag_2 - (2 * gram_matrix)
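    # Start from zeros and fill in only the off-diagonal entries, so torch.sqrt never
    # sees the zero diagonal and the diagonal entries receive no gradient.
    distance_matrix = torch.zeros_like(squared_distance_matrix)
    mask = ~torch.eye(L, dtype=torch.bool).to(coordinates.device)
    distance_matrix[mask] = torch.sqrt(squared_distance_matrix.masked_select(mask))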
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()  # clear gradients accumulated by the previous iteration
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()

which gives:

1.222102403640747
1.2191187143325806
1.2162436246871948
1.2133947610855103
1.210543155670166
1.2076761722564697
1.204787015914917
1.2018715143203735
1.198927402496338
1.1959534883499146
1.1929489374160767
1.1899129152297974
1.1868458986282349
1.1837480068206787
1.180619239807129
1.1774601936340332
1.174271583557129
...
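An alternative sketch that avoids the in-place masked assignment, assuming as the answer does that only the diagonal is zero. This swaps in a different technique, often called the "double where" pattern: the square root is fed a harmless value (1.0) on the diagonal, and `torch.where` then puts exact zeros back there. Because `torch.where` selects rather than multiplies, the diagonal gets an exactly zero gradient. The function name `compute_distance_matrix_where` is hypothetical.

```python
import torch


def compute_distance_matrix_where(coordinates):
    L = coordinates.shape[0]
    gram = coordinates @ coordinates.T
    sq_norms = torch.diagonal(gram)
    squared = sq_norms[:, None] + sq_norms[None, :] - 2 * gram
    off_diag = ~torch.eye(L, dtype=torch.bool, device=coordinates.device)
    # Replace the zero diagonal with 1.0 before sqrt, so its backward never computes 0 / 0 ...
    safe_squared = torch.where(off_diag, squared, torch.ones_like(squared))
    # ... then restore exact zeros on the diagonal; where() gives those entries zero gradient.
    return torch.where(off_diag, torch.sqrt(safe_squared), torch.zeros_like(squared))


torch.manual_seed(0)
pred = torch.randn(10, 3, requires_grad=True)
true = torch.randn(10, 3)
loss = torch.nn.MSELoss()(compute_distance_matrix_where(pred), compute_distance_matrix_where(true))
loss.backward()
print(torch.isfinite(pred.grad).all().item())  # True
```

The forward values match the answer's masked version; the difference is only that every step is out-of-place.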
Jacob Stern