I am trying to compute multiple loss gradients efficiently (without a for loop) in PyTorch. Given:

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 16, bias=False),
            nn.Linear(16, output_size, bias=False),
        )

    def forward(self, x):
        return self.linear(x)
    
device = "cpu"
input_size = 2
output_size = 2

x = torch.randn(10, 1, input_size).to(device)  # 10 samples, each of shape (1, input_size)
y = torch.randn(10, 1, output_size).to(device)  # matching targets

model = NeuralNetwork().to(device)
loss_fn = nn.MSELoss()

def loss_grad(x, label):
    # Gradient of the loss for a single sample w.r.t. all model parameters.
    y = model(x)
    loss = loss_fn(y, label)
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    return grads

The following works, but uses a for loop:

# inefficient but works
def compute_for():
    grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]
    print(grads)

compute_for()

For efficiency, I tried using torch.vmap instead:

# potentially more efficient but doesn't work
def compute_vmap():
    grads = torch.vmap(loss_grad)(x, y)
    print(grads)

compute_vmap()

I was expecting it to compute the gradients of the losses w.r.t. the parameters for each element in x, y. Instead, I get an error:

RuntimeError: element 0 of tensors does not require grad

As I understand it, this means that the per-sample slices taken from the tensor x do not individually require grad, so autograd cannot compute gradients from them.

How can I modify this code so that it computes all gradients? Or is there another method to do that?

  • What are "multiple loss gradients"? What's wrong with using the standard approach for computing gradients by running `loss.backward()`? – Mateen Ulhaq May 28 '23 at 22:50
  • `loss.backward()` computes gradients with respect to a single loss value. However, I want to compute the loss gradient for each element in the input tensor x and the corresponding target tensor y, without a for loop, for efficiency reasons. – Starfree May 29 '23 at 05:04
  • Why not take a look here https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch – Edwin Cheong May 29 '23 at 05:45

1 Answer

Per-sample gradients can be computed with `torch.func` (using `functional_call`, `vmap`, and `grad`), as shown in the PyTorch per-sample-gradients tutorial:

from torch.func import functional_call, vmap, grad

def compute_loss(params, buffers, sample, target):
    # Add a leading batch dimension so a single sample is processed as a batch of one.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # Run the model statelessly with the given parameters and buffers.
    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

# Detach parameters and buffers so they are treated as plain inputs to the functional call.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

# grad differentiates compute_loss w.r.t. its first argument (params);
# vmap maps that gradient function over the leading dimension of x and y,
# while broadcasting params and buffers (in_dims=None).
ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, x, y)

print(ft_per_sample_grads)

These match the gradients computed individually for each pair (x[i], y[i]).
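As a quick sanity check, here is a minimal sketch (assuming the `model`, `x`, `y`, and `loss_grad` defined in the question are still in scope) that compares the vmap result against the gradients from the for loop:

# Gradients from the original for loop: one tuple per sample, ordered like model.parameters().
loop_grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]

# ft_per_sample_grads is a dict keyed by parameter name; each entry has an extra
# leading dimension of size x.shape[0] holding the per-sample gradients.
for i in range(x.shape[0]):
    for (name, _), g_loop in zip(model.named_parameters(), loop_grads[i]):
        assert torch.allclose(g_loop, ft_per_sample_grads[name][i], atol=1e-6)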
