I need to compute the first- and second-order derivatives of a vector-valued function output of shape batch x nvarout (say nvarout = 150) with respect to the input x of shape batch x nvarin (say nvarin = 2).
I manage to do this with the following code:
import torch

def continuous_diff(x, y):
    torch.set_grad_enabled(True)
    x.requires_grad_(True)
    # x in [N, nvarin]
    # y in [N] (one output column) or [N, nvarout]
    # dy_dx in [N, nvarin]
    dy_dx = torch.autograd.grad(
        y, x, torch.ones_like(y),
        retain_graph=True, create_graph=True,
    )[0]
    return dy_dx
and
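output_grad2 = []
for k in range(output.shape[1]):
    y = output[:, k]
    dx = continuous_diff(x, y)  # first derivatives, [N, nvarin]
    # hardcoded for nvarin = 2 here
    dxx = continuous_diff(x, dx[:, 0])  # derivatives of dy/dx0, [N, nvarin]
    dyy = continuous_diff(x, dx[:, 1])  # derivatives of dy/dx1, [N, nvarin]
    grad2 = torch.concatenate([dxx, dyy], dim=-1)  # [N, nvarin * nvarin]
    output_grad2.append(grad2)
output_grad2 = torch.stack(output_grad2, dim=-1)  # [N, nvarin * nvarin, nvarout]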
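A minimal self-contained version of the loop above may help with checking shapes and timing. It is only a sketch: toy_model, second_derivatives and the sizes (N = 8, nvarout = 2) are illustrative assumptions standing in for the actual model, chosen so that the second derivatives can be verified by hand.

```python
import torch

def continuous_diff(x, y):
    # Same helper as above: per-sample gradient of y with respect to x.
    dy_dx = torch.autograd.grad(
        y, x, torch.ones_like(y),
        retain_graph=True, create_graph=True,
    )[0]
    return dy_dx

def toy_model(x):
    # Hypothetical stand-in for the real model (nvarin = 2, nvarout = 2).
    x0, x1 = x[:, 0], x[:, 1]
    return torch.stack([x0**2 * x1, torch.sin(x0) + x1**3], dim=-1)

def second_derivatives(x, output):
    # The loop from the question, wrapped in a function (nvarin = 2 hardcoded).
    output_grad2 = []
    for k in range(output.shape[1]):
        dx = continuous_diff(x, output[:, k])
        dxx = continuous_diff(x, dx[:, 0])
        dyy = continuous_diff(x, dx[:, 1])
        output_grad2.append(torch.cat([dxx, dyy], dim=-1))
    return torch.stack(output_grad2, dim=-1)  # [N, nvarin * nvarin, nvarout]

torch.manual_seed(0)
# x must require grad before the forward pass so the graph is recorded.
x = torch.randn(8, 2, requires_grad=True)
output = toy_model(x)
grad2 = second_derivatives(x, output)
print(grad2.shape)
```

For the first toy output, x0**2 * x1, the four entries along dim 1 are 2*x1, 2*x0, 2*x0 and 0, which gives a quick sanity check of the layout. The loop runs 1 + nvarin backward passes per output column, so the cost grows linearly with nvarout.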
Is there a way to speed up this computation?
A similar question was posted here, but no solution has been proposed in the two years since.