The above solution is not completely correct: it only works in the special case where the output dimension is 1.
As mentioned in the docs, the output of torch.autograd.grad is related to the derivatives, but it is not actually dy/dx. For example, assume you have a neural network that takes a tensor of shape (batch_size, input_dim) as input and produces a tensor of shape (batch_size, output_dim) as output. The derivatives of the output w.r.t. the input should have shape (batch_size, output_dim, input_dim), but what you get from torch.autograd.grad has shape (batch_size, input_dim), which is the sum of the true derivatives over the output dimension. If you want the full derivatives, use torch.autograd.functional.jacobian as follows:
>>> import torch
>>> torch.__version__
'1.10.1+cu111'
#!/usr/bin/env python
# coding: utf-8
import torch
from torch import nn
import numpy as np
batch_size = 10
hidden_dim = 20
input_dim = 3
output_dim = 2
model = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, output_dim)).double()
x = torch.rand(batch_size, input_dim, requires_grad=True, dtype=torch.float64) #(batch_size, input_dim)
y = model(x) #y: (batch_size, output_dim)
#using torch.autograd.grad
dydx1 = torch.autograd.grad(y, x, retain_graph=True, grad_outputs=torch.ones_like(y))[0] #dydx1: (batch_size, input_dim)
print(f' using grad dydx1: {dydx1.shape}')
#using torch.autograd.functional.jacobian
j = torch.autograd.functional.jacobian(lambda t: model(t), x) #j: (batch_size, output_dim, batch_size, input_dim)
#the off-diagonal elements of 0th and 2nd dimension are all zero. So we remove them
dydx2 = torch.diagonal(j, offset=0, dim1=0, dim2=2) #dydx2: (output_dim, input_dim, batch_size)
dydx2 = dydx2.permute(2, 0, 1) #dydx2: (batch_size, output_dim, input_dim)
print(f' using jacobian dydx2: {dydx2.shape}')
#round to 14 decimal digits to avoid noise
print(np.round((dydx2.sum(dim=1)).numpy(), 14) == np.round(dydx1.numpy(), 14))
Output:
using grad dydx1: torch.Size([10, 3])
using jacobian dydx2: torch.Size([10, 2, 3])
#dydx2.sum(dim=1) == dydx1
[[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]
[ True True True]]
In fact, autograd.grad returns the sum of dy/dx over the output dimension.
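To state that more precisely: grad_outputs plays the role of the vector v in a vector-Jacobian product (this is the standard reverse-mode identity from the autograd docs; the batch indexing below is just my bookkeeping for the shapes used here):

\mathrm{dydx1}_{b,j} \;=\; \sum_{i=1}^{\mathrm{output\_dim}} v_{b,i}\,\frac{\partial y_{b,i}}{\partial x_{b,j}}, \qquad v = \mathrm{grad\_outputs}

With v = torch.ones_like(y) every v_{b,i} equals 1, so each entry of dydx1 is a row sum of the per-sample Jacobian, which is exactly why dydx2.sum(dim=1) matched dydx1 above.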
If you really want to stay with torch.autograd.grad, there is a (less efficient) way to get the full derivatives: call it once per output component with a one-hot grad_outputs:
dydx3 = torch.tensor([], dtype=torch.float64)
for i in range(output_dim):
    #one-hot grad_outputs selecting the i-th output component
    l = torch.zeros_like(y)
    l[:, i] = 1.
    d = torch.autograd.grad(y, x, retain_graph=True, grad_outputs=l)[0] #d: (batch_size, input_dim)
    dydx3 = torch.concat((dydx3, d.unsqueeze(dim=1)), dim=1) #stack along a new output dimension
print(f' dydx3: {dydx3.shape}')
print(np.round(dydx3.numpy(), 14) == np.round(dydx2.numpy(), 14))
Output:
dydx3: torch.Size([10, 2, 3])
[[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]
[[ True True True]
[ True True True]]]
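Side note (an alternative I have not benchmarked here, so treat it as a sketch): if your PyTorch version has the experimental vectorize flag of torch.autograd.functional.jacobian (it is present in 1.10), you can let autograd batch those per-output backward passes instead of looping in Python:
#sketch: assumes the experimental vectorize=True flag of jacobian is available
j_vec = torch.autograd.functional.jacobian(model, x, vectorize=True) #j_vec: (batch_size, output_dim, batch_size, input_dim)
dydx4 = torch.diagonal(j_vec, dim1=0, dim2=2).permute(2, 0, 1) #dydx4: (batch_size, output_dim, input_dim)
This should match dydx2 and dydx3 up to floating-point noise.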
I hope it helps.
P.S. I used retain_graph=True because the same graph is backpropagated through multiple times (once per autograd.grad call).