How to compute the Hessian matrix of a large neural network or transformer model like BERT in PyTorch?

I know about torch.autograd.functional.hessian, but it seems to compute the Hessian of a plain function, not of a neural network. I also saw the answer in How to compute hessian matrix for all parameters in a network in pytorch?. The problem is that I want the Hessian with respect to the weights, and for large neural networks it is very inefficient to rewrite the model as a function of its weights. Is there a better way to do this? Any suggestion is appreciated. Thanks.

Yan Pan

After some time I found a new feature in the PyTorch nightly build that solves this problem. The details are described in this comment: https://github.com/pytorch/pytorch/issues/49171#issuecomment-933814662. The solution combines torch.autograd.functional.hessian with the new torch.nn.utils._stateless module, which lets you call the existing model with a different set of weights instead of rewriting it as a function of them. Note that you have to install the nightly version of PyTorch to use this feature.
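A minimal sketch of this approach, assuming PyTorch 2.0 or later, where the private torch.nn.utils._stateless helper from the nightly build is available publicly as torch.func.functional_call. The toy model, the data, and loss_fn are hypothetical stand-ins for your own network and loss. functional_call runs the unchanged model with the tensors you pass in place of its registered parameters, so torch.autograd.functional.hessian can differentiate the loss with respect to those tensors directly. Keep in mind that the full Hessian has n² entries for n parameters, so for a model the size of BERT (roughly 110 million parameters) it cannot be materialized; this only works for small models or a small subset of the parameters.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import hessian
from torch.func import functional_call  # PyTorch >= 2.0 (assumption)

torch.manual_seed(0)

# Hypothetical tiny model and data; the full Hessian is n x n, so keep n small.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 1))
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Parameter names plus detached copies that serve as the Hessian's inputs.
names = [name for name, _ in model.named_parameters()]
params = tuple(p.detach() for _, p in model.named_parameters())

def loss_fn(*flat_params):
    # Run the unchanged model with flat_params substituted for its weights.
    out = functional_call(model, dict(zip(names, flat_params)), (x,))
    return nn.functional.mse_loss(out, y)

# With a tuple of inputs, H[i][j] has shape params[i].shape + params[j].shape.
H = hessian(loss_fn, params)

# Flatten the blocks into a single (n, n) matrix, n = total parameter count.
sizes = [p.numel() for p in params]
H_full = torch.cat(
    [torch.cat([H[i][j].reshape(sizes[i], sizes[j]) for j in range(len(params))], dim=1)
     for i in range(len(params))],
    dim=0,
)
print(H_full.shape)  # torch.Size([21, 21])
```

The model is never rewritten by hand: the same nn.Module is reused, and only the tensors handed to functional_call change, which is what makes this practical for arbitrary architectures.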

Yan Pan