I'm currently trying to implement an ODE solver with PyTorch; my solution requires computing the gradient of each output with respect to its input.
x.requires_grad_(True)  # the input must track gradients for x.grad to be populated
y = model(x)
for i in range(len(y)):  # compute each output's grad w.r.t. the input
    y[i].backward(retain_graph=True)  # accumulates into x.grad
ydx = x.grad  # per-sample gradients, assuming each y[i] depends only on x[i]
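For reference, this is roughly what the second-order version looks like with the same per-output pattern. It is only a minimal sketch: the model and input below are hypothetical stand-ins for my actual setup, and I'm assuming create_graph=True is what makes the first gradient differentiable a second time.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))  # placeholder network
x = torch.linspace(0.0, 1.0, 8).unsqueeze(1).requires_grad_(True)     # placeholder inputs

y = model(x)
ydx2 = torch.zeros_like(x)
for i in range(len(y)):
    (g,) = torch.autograd.grad(y[i], x, create_graph=True)   # dy_i/dx, still differentiable
    (g2,) = torch.autograd.grad(g[i], x, retain_graph=True)  # d2(y_i)/dx2
    ydx2[i] = g2[i]  # second derivative of output i w.r.t. input i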
I was wondering if there is a more elegant way to compute the gradients for each output in the batch, since, as the sketch above shows, the code gets messy for higher-order ODEs and PDEs. I tried using:
torch.autograd.backward(x,y,retain_graph=True)
without much success.
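Rereading the docs, torch.autograd.backward expects the tensors to differentiate first and the gradient tensors second, so I suspect my arguments were simply swapped. Below is a sketch of what I now think the vectorized call should look like; it assumes each y[i] depends only on x[i], so the vector-Jacobian product with a vector of ones collapses to the per-sample gradients (model and x are the same placeholders as in the sketch above):

y = model(x)
ydx = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),
                          create_graph=True)[0]  # all dy_i/dx_i in one backward pass
# equivalently, with the argument order of my earlier attempt fixed:
# torch.autograd.backward(y, grad_tensors=torch.ones_like(y), retain_graph=True)
# ydx = x.grad

Is something like this the intended approach for batched per-sample gradients?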