
I am trying to visualise a multivariate normal distribution with matplotlib. I would like to produce something like this:

[image not reproduced]

I use the following code:

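import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # registers the '3d' projection (required on older matplotlib versions)
x = np.linspace(-1, 3, 100)
y = np.linspace(0, 4, 100)
X, Y = np.meshgrid(x, y)  # X and Y both have shape (100, 100)
Z = np.random.multivariate_normal(mean=[1, 2], cov=np.array([[0.5, 0.25], [0.25, 0.50]]), size=100000)  # Z has shape (100000, 2)
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
ax.set_title('surface');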

But I get the following error message:

...
      7 ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
----> 8                 cmap='viridis', edgecolor='none')
...
ValueError: shape mismatch: objects cannot be broadcast to a single shape

What is the reason for the error, and how can my code be corrected?

Trenton McKinney
user8270077
    Shouldn't `Z` be a function of the `x, y` pairs, as illustrated [here](https://stackoverflow.com/a/9170879/8881141)? At the moment you have probably created an array of a different size for Z. – Mr. T Jan 26 '18 at 16:51
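A minimal sketch of the mismatch the comment describes (an illustration added here, not code from the thread; it assumes NumPy 1.20 or later for np.broadcast_shapes, and the helper can_broadcast is hypothetical). The traceback message is NumPy's broadcasting error, which suggests plot_surface tries to broadcast X, Y and Z to one common shape. The grid is (100, 100), but the sampled Z is (100000, 2): a list of random (x, y) points, not a height for each grid point.

```python
import numpy as np

# Same grid and sampling call as in the question
x = np.linspace(-1, 3, 100)
y = np.linspace(0, 4, 100)
X, Y = np.meshgrid(x, y)
samples = np.random.multivariate_normal(mean=[1, 2], cov=[[0.5, 0.25], [0.25, 0.5]], size=100000)


def can_broadcast(*shapes):
    """Hypothetical helper: True if NumPy can broadcast the shapes to one common shape."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


print(X.shape, Y.shape)                       # (100, 100) (100, 100)
print(samples.shape)                          # (100000, 2): random points, not heights
print(can_broadcast(X.shape, Y.shape))        # True
print(can_broadcast(X.shape, samples.shape))  # False: last axes are 100 vs 2
```

So Z has to be computed from the grid itself (one value per (x, y) pair), which is what the answer below does with a density function.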

1 Answer


I have done this with scipy.stats.multivariate_normal, using its pdf method to generate the z values. As @Piinthesky pointed out, the numpy function np.random.multivariate_normal draws random samples from the distribution: it returns an array of (x, y) points with shape (size, 2), not density values on your grid, so it cannot be broadcast against X and Y. An example using the scipy version can be found in Python add gaussian noise in a radius around a point [closed]:

Tested in Python 3.11.2, scipy 1.10.1, matplotlib 3.7.1, numpy 1.24.3

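import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import numpy as np

# Grid of (x, y) points on which to evaluate the density
x = np.linspace(-1, 3, 100)
y = np.linspace(0, 4, 100)
X, Y = np.meshgrid(x, y)         # each has shape (100, 100)
pos = np.dstack((X, Y))          # shape (100, 100, 2): one (x, y) pair per grid point

# Distribution parameters (same as in the question)
mu = np.array([1, 2])
cov = np.array([[.5, .25], [.25, .5]])
rv = multivariate_normal(mu, cov)
Z = rv.pdf(pos)                  # density at every grid point, shape (100, 100)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z)
plt.show()                       # blocks until the window is closed when run as a script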

[image: the resulting 3D surface plot of the density]
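If scipy is not available, the same Z can be sketched with NumPy alone, assuming the standard closed-form bivariate normal density f(p) = exp(-(p - mu)^T cov^-1 (p - mu) / 2) / (2 pi sqrt(det cov)). The helper name bivariate_normal_pdf is hypothetical (not part of NumPy or scipy); on the same grid it should agree with rv.pdf(pos) up to floating-point error.

```python
import numpy as np


def bivariate_normal_pdf(pos, mu, cov):
    """Hypothetical helper: 2D normal density at points whose last axis is (x, y)."""
    diff = pos - mu                            # shape (..., 2)
    inv = np.linalg.inv(cov)
    # Quadratic form (p - mu)^T cov^-1 (p - mu), evaluated at every point at once
    maha = np.einsum('...i,ij,...j->...', diff, inv, diff)
    norm = 2 * np.pi * np.sqrt(np.linalg.det(cov))
    return np.exp(-0.5 * maha) / norm


# Same grid and parameters as the answer above
x = np.linspace(-1, 3, 100)
y = np.linspace(0, 4, 100)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))                        # shape (100, 100, 2)
mu = np.array([1, 2])
cov = np.array([[.5, .25], [.25, .5]])

Z_np = bivariate_normal_pdf(pos, mu, cov)      # shape (100, 100), usable as plot_surface(X, Y, Z_np)
print(Z_np.shape)                              # (100, 100)
peak = bivariate_normal_pdf(mu, mu, cov)       # density at the mean: 1 / (2*pi*sqrt(det(cov)))
print(peak)
```

np.einsum evaluates the quadratic form for all grid points without a Python loop. scipy.stats.multivariate_normal additionally validates its inputs and offers an allow_singular option, so it remains the more convenient choice when scipy is installed.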

Trenton McKinney
Grr