
Why doesn't this code run for nsamples=3? It runs for nsamples in (1,2).

from scipy.stats import multivariate_normal
import numpy as np

mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
state = np.meshgrid(x, y)
print(state)
rv.logpdf(state)

Here's the error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-595249b070ac> in <module>()
      4 state = np.meshgrid(x, y)
      5 print state
----> 6 rv.logpdf(state)

/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in logpdf(self, x)
    518         x = _process_quantiles(x, self.dim)
    519         out = self._mnorm._logpdf(x, self.mean, self.cov_info.U,
--> 520                                   self.cov_info.log_pdet, self.cov_info.rank)
    521         return _squeeze_output(out)
    522 

/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
    377 
    378         """
--> 379         dev = x - mean
    380         maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
    381         return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)

ValueError: operands could not be broadcast together with shapes (2,3,3) (2,) 

It seems like there's a bug in the library: I think either x needs to have an axis rolled or mean needs to be reshaped.

Steve Schulist

1 Answer


np.meshgrid returns a sequence (here a list) of 2D arrays, one per coordinate, not an array of points:

In [124]: np.meshgrid(x, y)
Out[124]: 
[array([[-1.,  0.,  1.],
        [-1.,  0.,  1.],
        [-1.,  0.,  1.]]), array([[-2., -2., -2.],
        [ 0.,  0.,  0.],
        [ 2.,  2.,  2.]])]
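Stacked into one array, that list has shape (2, nsamples, nsamples), as the shapes in the error message show, so its last axis has length nsamples instead of 2. The failing line in the traceback, dev = x - mean, then broadcasts that array against mean, whose shape is (2,). The sketch below reproduces only that subtraction with plain NumPy (no SciPy call); broadcasts_with_mean is a hypothetical helper written just for this check.

```python
import numpy as np

mean = np.array([0, 0])  # shape (2,), as in the question


def broadcasts_with_mean(nsamples):
    """Hypothetical helper: does the stacked meshgrid broadcast against mean?"""
    x = np.linspace(-1, 1, nsamples)
    y = np.linspace(-2, 2, nsamples)
    state = np.asarray(np.meshgrid(x, y))  # shape (2, nsamples, nsamples)
    try:
        state - mean  # the same subtraction that fails in the traceback
    except ValueError:
        return state.shape, False
    return state.shape, True


for n in (1, 2, 3):
    print(n, broadcasts_with_mean(n))
```

For nsamples = 1 and 2 the subtraction happens to broadcast, so no error is raised, but the length-2 "points" along the last axis are not grid points: for nsamples = 2 they are rows such as (-1, 1) from the x grid and (-2, -2) from the y grid, instead of pairs like (-1, -2). So, at least with the SciPy version in the traceback, those cases run but return the wrong values.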

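rv.logpdf expects a list of 2-tuples, or an array whose last axis has length 2 (one (x, y) point per row). For example, with state = np.array(list(IT.product(x, y))), which has shape (9, 2):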

In [128]: state
Out[128]: 
array([[-1., -2.],
       [-1.,  0.],
       [-1.,  2.],
       [ 0., -2.],
       [ 0.,  0.],
       [ 0.,  2.],
       [ 1., -2.],
       [ 1.,  0.],
       [ 1.,  2.]])

In [129]: rv.logpdf(state)
Out[129]: 
array([-3.53102425, -3.03102425, -3.53102425, -3.03102425, -2.53102425,
       -3.03102425, -3.53102425, -3.03102425, -3.53102425])

In [131]: rv.logpdf(state.reshape(3,3,-1))
Out[131]: 
array([[-3.53102425, -3.03102425, -3.53102425],
       [-3.03102425, -2.53102425, -3.03102425],
       [-3.53102425, -3.03102425, -3.53102425]])
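As a sanity check on these numbers: the covariance is diagonal with variances 1 and 4, so the expression in the traceback's _logpdf, -0.5 * (rank * _LOG_2PI + log_det_cov + maha), reduces to -0.5 * (2*log(2*pi) + log(4) + x**2 + y**2/4). The sketch below evaluates that formula directly with NumPy. diag_gauss_logpdf is a hypothetical name, and the code assumes a zero mean and a diagonal covariance, as in this question.

```python
import numpy as np


def diag_gauss_logpdf(px, py, var_x=1.0, var_y=4.0):
    """Hypothetical helper: log-density of a zero-mean 2D Gaussian with diagonal covariance."""
    log_det_cov = np.log(var_x * var_y)        # log(1 * 4) for this covariance
    maha = px ** 2 / var_x + py ** 2 / var_y   # squared Mahalanobis distance
    return -0.5 * (2 * np.log(2 * np.pi) + log_det_cov + maha)


x = np.linspace(-1, 1, 3)
y = np.linspace(-2, 2, 3)
X, Y = np.meshgrid(x, y, indexing='ij')  # X[i, j] = x[i], Y[i, j] = y[j]
print(diag_gauss_logpdf(X, Y))
# centre entry is about -2.53102425, corner entries about -3.53102425
```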

So instead of np.meshgrid you could use itertools.product (imported as IT):

state = np.array(list(IT.product(x, y)))

or, for better speed when x and y are large, use pv's cartesian function.


from scipy.stats import multivariate_normal
import numpy as np
import itertools as IT

mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
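# One (x, y) point per row: shape (nsamples**2, 2)
state = np.array(list(IT.product(x, y)))
# Reshape to (nsamples, nsamples, 2) so the log-densities come back as a grid
logpdf = rv.logpdf(state.reshape(nsamples, nsamples, -1))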
print(logpdf)

yields

[[-3.53102425 -3.03102425 -3.53102425]
 [-3.03102425 -2.53102425 -3.03102425]
 [-3.53102425 -3.03102425 -3.53102425]]
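If building the list with itertools.product is too slow for large grids, a vectorized option (an alternative, not pv's cartesian function) is to keep np.meshgrid and move the coordinate axis to the end, which is essentially the axis roll the question proposed. With indexing='ij', np.stack(..., axis=-1) produces the same (nsamples, nsamples, 2) array as the product-and-reshape version above, so logpdf returns the same grid. This is a sketch assuming NumPy 1.10 or later (for np.stack).

```python
import itertools as IT

import numpy as np
from scipy.stats import multivariate_normal

mean = np.array([0, 0])
covar = np.array([[1, 0], [0, 4]])
rv = multivariate_normal(mean, covar)

nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)

# Stack the two coordinate grids along a new last axis: shape (nsamples, nsamples, 2).
# indexing='ij' gives grid[i, j] == (x[i], y[j]), matching itertools.product order.
grid = np.stack(np.meshgrid(x, y, indexing='ij'), axis=-1)

# Same points as the itertools.product approach, reshaped into a grid.
state = np.array(list(IT.product(x, y))).reshape(nsamples, nsamples, -1)
assert np.array_equal(grid, state)

logpdf = rv.logpdf(grid)  # shape (nsamples, nsamples)
print(logpdf)
```

With the default indexing='xy', the stacked grid would instead satisfy grid[j, i] == (x[i], y[j]), i.e. the transpose; the printed values here would look the same only because this particular 3x3 result is symmetric.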
unutbu