
I have implemented the following Jacobian function in PyTorch. Unless I have made a mistake, it computes the Jacobian of any tensor w.r.t. inputs of any dimension:

```python
import torch
import torch.autograd as ag

def nd_range(stop, dims=None):
    # Yield every index tuple for a tensor of shape `stop`, last index fastest.
    if dims is None:
        dims = len(stop)
    if not dims:
        yield ()
        return
    for outer in nd_range(stop, dims - 1):
        for inner in range(stop[dims - 1]):
            yield outer + (inner,)


def full_jacobian(f, wrt):
    f_shape = list(f.size())
    wrt_shape = list(wrt.size())
    fs = []

    f_range = nd_range(f_shape)
    wrt_range = nd_range(wrt_shape)

    for f_ind in f_range:
        # Gradient of one element of f w.r.t. wrt; keep the graph for higher orders.
        grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
        for i in range(len(f_shape)):
            grad = grad.unsqueeze(0)
        fs.append(grad)

    fj = torch.cat(fs, dim=0)
    fj = fj.view(f_shape + wrt_shape)
    return fj
```
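The `nd_range` helper above enumerates every index tuple of a shape with the last index varying fastest (row-major order), which is the order the final `view(f_shape + wrt_shape)` relies on. As a sketch, the standard library's `itertools.product` yields the same tuples; the name `nd_range_product` is hypothetical:

```python
import itertools


def nd_range_product(stop):
    # Hypothetical drop-in for nd_range(stop): the cartesian product of
    # range(n) for each n in stop, with the last index varying fastest.
    # For an empty shape it yields a single empty tuple, like nd_range.
    return itertools.product(*(range(n) for n in stop))


print(list(nd_range_product([2, 3])))
```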

On top of this, I have tried to implement a recursive function to calculate nth-order derivatives:

```python
def nth_derivative(f, wrt, n):
    if n == 1:
        return full_jacobian(f, wrt)
    else:
        deriv = nth_derivative(f, wrt, n - 1)
        return full_jacobian(deriv, wrt)
```

I ran a simple test:

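```python
# s is a 1-D tensor with requires_grad=True (its definition is not shown).
# torch.ger is a deprecated alias of torch.outer in current PyTorch.
op = torch.ger(s, s)
deep_deriv = nth_derivative(op, s, 5)
```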

Unfortunately, this succeeds in computing the Hessian, but no higher-order derivatives. I'm aware that many higher-order derivatives should be 0, but I'd prefer PyTorch to compute that analytically.

One fix has been to change the gradient calculation to:

```python
try:
    grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
except RuntimeError:
    # ag.grad raises RuntimeError when f no longer requires grad.
    grad = torch.zeros_like(wrt)
```
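A narrower variant of this fix, sketched under the assumption that the error comes from the derivative tensor no longer requiring grad (which can happen once a derivative is constant, as the Hessian of `torch.ger(s, s)` is): check `f.requires_grad` up front instead of catching exceptions, and pass `allow_unused=True` so that elements that do not depend on `wrt` come back as `None` and are replaced by zeros. This is an illustrative rewrite, not the code from the question, and it uses `torch.outer` in place of the deprecated `torch.ger`:

```python
import torch
import torch.autograd as ag


def full_jacobian(f, wrt):
    """Derivative of every element of f w.r.t. wrt, shape f.shape + wrt.shape."""
    out_shape = tuple(f.shape) + tuple(wrt.shape)
    if not f.requires_grad or f.numel() == 0:
        # f is not connected to the autograd graph (e.g. a constant
        # derivative), so all of its derivatives w.r.t. wrt are zero.
        return torch.zeros(out_shape, dtype=wrt.dtype, device=wrt.device)

    flat = f.reshape(-1)
    rows = []
    for k in range(flat.numel()):
        # allow_unused=True returns None instead of raising when this
        # element does not depend on wrt.
        (g,) = ag.grad(flat[k], wrt, retain_graph=True, create_graph=True,
                       allow_unused=True)
        rows.append(torch.zeros_like(wrt) if g is None else g)
    return torch.stack(rows).reshape(out_shape)


def nth_derivative(f, wrt, n):
    # Apply full_jacobian n times; each intermediate result keeps its graph.
    for _ in range(n):
        f = full_jacobian(f, wrt)
    return f


s = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
op = torch.outer(s, s)
print(nth_derivative(op, s, 5).shape)  # torch.Size([3, 3, 3, 3, 3, 3, 3])
```

Because the zero case is decided by `requires_grad` rather than by any exception, unrelated errors (shape mistakes, a `wrt` that does not require grad) still surface instead of being silently turned into zeros.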

Is this the accepted, correct way to handle this? Or is there a better option? Or do I have the cause of my issue completely wrong to begin with?

lurscher
user650261

2 Answers


You can simply call the grad function iteratively:

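```python
import torch
from torch.autograd import grad

def nth_derivative(f, wrt, n):
    for i in range(n):
        # create_graph=True keeps the derivative differentiable for the next pass.
        grads = grad(f, wrt, create_graph=True)[0]
        f = grads.sum()
    return grads

# arange must produce a floating-point tensor for requires_grad=True.
x = torch.arange(4, dtype=torch.float32, requires_grad=True).reshape(2, 2)
loss = (x ** 4).sum()

print(nth_derivative(f=loss, wrt=x, n=3))
```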

which outputs:

```
tensor([[  0.,  24.],
        [ 48.,  72.]])
```
Alex
  • Not the cleanest solution, but it'll work. Does such an iterative method affect performance? – user650261 May 16 '18 at 17:32
  • "Does such an iterative method affect performance?" is a very vague question. It may or it may not depending on the rest of your code. Unless you are having performance problems and identify this to be a bottleneck, don't sweat it. To [quote Donald Knuth](http://wiki.c2.com/?PrematureOptimization) "premature optimization is the root of all evil". – Alex May 16 '18 at 17:39
  • @user650261 Also what do you mean by "Not the cleanest solution"? – Alex May 16 '18 at 17:40
  • Is the main idea behind this approach to simply create a new computational graph every time you take the derivative (and then take the derivative on this new computational graph/function)? – information_interchange Mar 21 '20 at 19:14
  • @information_interchange yes: "`create_graph` (*bool, optional*) – If `True`, graph of the derivative will be constructed, allowing to compute higher order derivative products." – iacob Mar 31 '21 at 10:16
  • There is a big problem here; surprised no one has mentioned it before. Getting sum and then taking derivative yields the second derivative PLUS the sum of cross-derivatives (and they accumulate more as the for loop continues). In the example above (i.e., x**4), all of those cross-derivatives are zero by construction. – Mehdi Jan 29 '23 at 16:17
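To make the last comment concrete: for n = 2 the loop differentiates `grads.sum()`, so it returns the row sums of the Hessian (the Hessian-vector product with a vector of ones) rather than its diagonal. The sketch below uses a hypothetical non-separable function `func` and compares the two using `torch.autograd.functional.hessian`; for a separable loss such as `(x ** 4).sum()` the off-diagonal terms vanish and both agree.

```python
import torch
from torch.autograd import grad
from torch.autograd.functional import hessian


def nth_derivative(f, wrt, n):
    # Iterated-grad approach from the answer above.
    for i in range(n):
        grads = grad(f, wrt, create_graph=True)[0]
        f = grads.sum()
    return grads


def func(x):
    # Hypothetical non-separable function: f(x) = x0**2 * x1**2.
    return x[0] ** 2 * x[1] ** 2


x = torch.tensor([1.0, 2.0], requires_grad=True)
summed = nth_derivative(func(x), x, n=2).detach()
H = hessian(func, x)  # full 2x2 matrix of second derivatives

print(summed)             # tensor([16., 10.]): row sums of H
print(torch.diagonal(H))  # tensor([8., 2.]): the pure second derivatives
```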

For the second-order derivative, you can use PyTorch's hessian function:

`torch.autograd.functional.hessian()`
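As a sketch, assuming the same `(x ** 4).sum()` loss as the first answer wrapped in a hypothetical `loss_fn`: `hessian` takes a function that returns a scalar and returns all second derivatives at once, with shape `x.shape + x.shape`.

```python
import torch
from torch.autograd.functional import hessian


def loss_fn(x):
    # Hypothetical wrapper around the loss used in the first answer.
    return (x ** 4).sum()


x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
H = hessian(loss_fn, x)

print(H.shape)  # torch.Size([2, 2, 2, 2])
# H[i, j, k, l] = d2 loss / (dx[i, j] dx[k, l]). The loss is separable, so
# only entries with (i, j) == (k, l) can be nonzero; they equal 12 * x[i, j]**2.
print(H.reshape(4, 4).diagonal())
```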

For higher-order derivatives, you can repeatedly call `jacobian` or `grad` while maintaining the computational graph:

> `create_graph` (*bool, optional*) – If `True`, graph of the derivative will be constructed, allowing to compute higher order derivative products.
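Following that suggestion, here is a sketch that builds the n-th derivative by nesting `torch.autograd.functional.jacobian` with `create_graph=True`; the helper names `_differentiate` and `nth_jacobian` are hypothetical. With the default `strict=False`, `jacobian` returns zeros when an output does not depend on the input, so the constant-derivative case from the question yields zero tensors instead of an error.

```python
import torch
from torch.autograd.functional import jacobian


def _differentiate(h):
    # Hypothetical helper: turn h into y -> Jacobian of h at y, built with
    # create_graph=True so the result can be differentiated again.
    def dh(y):
        return jacobian(h, y, create_graph=True)
    return dh


def nth_jacobian(func, x, n):
    # Hypothetical helper: n-th derivative of func at x, with shape
    # func(x).shape followed by x.shape repeated n times.
    g = func
    for _ in range(n):
        g = _differentiate(g)
    return g(x)


s = torch.tensor([1.0, 2.0, 3.0])
d5 = nth_jacobian(lambda v: torch.outer(v, v), s, 5)
print(d5.shape)  # torch.Size([3, 3, 3, 3, 3, 3, 3])
```

Each nesting level evaluates the previous one once and then loops over its output elements, so the cost grows with the size of the derivative tensor; for large inputs the number of entries (and the work) grows geometrically with n.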

iacob