Why doesn't this code run for nsamples=3
? It runs for nsamples in (1,2)
.
from scipy.stats import multivariate_normal
import numpy as np
mean = np.array([0,0])
covar = np.array([[1,0],[0,4]])
rv = multivariate_normal(mean, covar)
nsamples = 3
x = np.linspace(-1, 1, nsamples)
y = np.linspace(-2, 2, nsamples)
state = np.meshgrid(x, y)
print state
rv.logpdf(state)
Here's the error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-11-595249b070ac> in <module>()
4 state = np.meshgrid(x, y)
5 print state
----> 6 rv.logpdf(state)
/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in logpdf(self, x)
518 x = _process_quantiles(x, self.dim)
519 out = self._mnorm._logpdf(x, self.mean, self.cov_info.U,
--> 520 self.cov_info.log_pdet, self.cov_info.rank)
521 return _squeeze_output(out)
522
/appl/pm/vendor/dev/python/lx-x86_64/2.7.9/lib/python2.7/site-packages/scipy/stats/_multivariate.pyc in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
377
378 """
--> 379 dev = x - mean
380 maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
381 return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)
ValueError: operands could not be broadcast together with shapes (2,3,3) (2,)
Seems like there's a bug in the library: I think either x
needs to roll an axis or mean
needs to be reshaped.