
While using the multivariate normal distribution in PyTorch, I decided to compare it with the exact analytical expression.

To my surprise, there was a small difference between them.

Is there any reason for this behaviour?

Firstly, calculate probabilities using MultivariateNormal:

from torch.distributions.multivariate_normal import MultivariateNormal
import torch
sigma = 2
m = MultivariateNormal(torch.zeros(2, dtype=torch.float32), torch.eye(2, dtype=torch.float32)*sigma**2)
values_temp = torch.zeros(size=(1,2), dtype=torch.float32)
out_torch = torch.exp(m.log_prob(values_temp))
out_torch
Out: tensor([0.0398])

Secondly, one can write the exact formula for this case:

import numpy as np
out_exact = 1/(2*np.pi*sigma**2) * torch.exp(-torch.pow(values_temp, 2).sum(dim=-1)/(2*sigma**2))
out_exact
Out: tensor([0.0398])

There is a difference between them:

(out_torch - out_exact).sum()
Out: tensor(3.7253e-09)

Can someone help me understand the behaviour of these two snippets? Which of the two expressions is more precise? Or is there a mistake somewhere in my code?

Oiale

1 Answer


Most modern systems use the IEEE 754 standard to represent fixed-precision floating point values. Because of this, we can be sure that neither the result provided by PyTorch nor the "exact" value you've computed is actually equal to the analytical expression. We know this because the true value of the expression is certainly irrational, and IEEE 754 cannot exactly represent any irrational number. This is a general phenomenon when using fixed-precision floating point representations, which you can read more about on Wikipedia and this question.

Upon further analysis, we find that the normalized difference you're seeing is on the order of machine epsilon (i.e. 3.7253e-09 / 0.0398 is approximately equal to torch.finfo(torch.float32).eps), indicating that the difference is likely just a result of the inaccuracies of floating-point arithmetic.
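You can check this numerically rather than by hand: compute the normalized difference directly and compare it against float32 machine epsilon (a quick sketch; the exact residual may vary slightly between PyTorch versions, but it should remain within a few epsilon of zero):

```python
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

sigma = 2
m = MultivariateNormal(torch.zeros(2), torch.eye(2) * sigma**2)
values_temp = torch.zeros(1, 2)

out_torch = torch.exp(m.log_prob(values_temp))
out_exact = 1 / (2 * np.pi * sigma**2) * torch.exp(
    -torch.pow(values_temp, 2).sum(dim=-1) / (2 * sigma**2))

# Normalized (relative) difference, compared against float32 machine epsilon
rel = ((out_torch - out_exact).abs() / out_exact).item()
eps = torch.finfo(torch.float32).eps  # ~1.19e-07
```

Here `rel` stays within a small multiple of `eps`, which is exactly what you expect from accumulated float32 rounding.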

For a further demonstration, we can write an expression mathematically equivalent to yours as

out_exact = torch.exp(np.log(1/ (2*np.pi*sigma**2)) + (-torch.pow(values_temp, 2).sum(dim=-1)/2/sigma**2))

which agrees exactly with the value given by my current installation of PyTorch.
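Another way to convince yourself that the discrepancy is purely a precision artifact is to repeat the whole computation in double precision; the difference then drops by many orders of magnitude below the 3.7253e-09 you observed (a sketch of the same setup with float64 tensors):

```python
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

sigma = 2
# Same computation as in the question, but in double precision
m = MultivariateNormal(torch.zeros(2, dtype=torch.float64),
                       torch.eye(2, dtype=torch.float64) * sigma**2)
values_temp = torch.zeros(1, 2, dtype=torch.float64)

out_torch = torch.exp(m.log_prob(values_temp))
out_exact = 1 / (2 * np.pi * sigma**2) * torch.exp(
    -torch.pow(values_temp, 2).sum(dim=-1) / (2 * sigma**2))

# The float64 residual is far smaller than the float32 one
diff = (out_torch - out_exact).abs().item()
```

If float32 rounding were not the culprit, switching dtypes would leave the difference unchanged; instead it shrinks to roughly float64 machine epsilon times the value.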

jodag