
What is the most efficient way to plot a 3D array in Python?

For example:

volume = np.random.rand(512, 512, 512)

where the array items represent the grayscale color of each voxel.


The following code is too slow:

from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
volume = np.random.rand(20, 20, 20)
for x in range(volume.shape[0]):
    for y in range(volume.shape[1]):
        for z in range(volume.shape[2]):
            v = volume[x, y, z]
            ax.scatter(x, y, z, color=(v, v, v, 1))
plt.show()
Dmitry
  • If you want to navigate through the data rather than plot it all at once, this could help: https://www.datacamp.com/community/tutorials/matplotlib-3d-volumetric-data – ferdymercury Apr 15 '21 at 21:14

3 Answers


For better performance, avoid calling ax.scatter multiple times, if possible. Instead, pack all the x,y,z coordinates and colors into 1D arrays (or lists), then call ax.scatter once:

ax.scatter(x, y, z, c=volume.ravel())
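A minimal sketch of the single-call approach, assuming `volume` is a 3D NumPy array of values in [0, 1] and a recent Matplotlib (3.2 or newer, where `projection="3d"` works without importing `Axes3D`). It builds the coordinates with `np.indices` (one possible way to get `x`, `y`, `z`, not necessarily the one intended here) and passes the raw values with a grayscale colormap. The helper names `flatten_volume` and `plot_volume` are hypothetical.

```python
import numpy as np


def flatten_volume(volume):
    """Return 1D x, y, z coordinate arrays and the matching 1D values.

    np.indices(shape) has shape (3, *shape); reshaping to (3, -1) flattens
    each coordinate grid in C order, the same order volume.ravel() uses,
    so element i of x, y, z and values all describe the same voxel.
    """
    x, y, z = np.indices(volume.shape).reshape(3, -1)
    return x, y, z, volume.ravel()


def plot_volume(volume):
    # Imported here so flatten_volume can be used without matplotlib.
    import matplotlib.pyplot as plt

    x, y, z, values = flatten_volume(volume)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    # One scatter call for all points instead of one call per point.
    ax.scatter(x, y, z, c=values, cmap="gray")
    plt.show()


if __name__ == "__main__":
    plot_volume(np.random.rand(20, 20, 20))
```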

The problem (in terms of both CPU time and memory) grows as size**3, where size is the side length of the cube.
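To put numbers on the size**3 growth, here is a rough estimate that counts only the 1D input arrays handed to `ax.scatter` (x, y, z and the color values, assumed to be 8-byte float64 or int64 each). Matplotlib's own internal copies come on top of this, so treat it as a lower bound. The function name `scatter_input_bytes` is hypothetical.

```python
def scatter_input_bytes(size, arrays=4, itemsize=8):
    """Bytes for `arrays` 1D arrays with one entry per voxel of a size**3 cube."""
    points = size ** 3
    return points * arrays * itemsize


for size in (100, 512):
    n_bytes = scatter_input_bytes(size)
    print(f"size={size}: {size ** 3:,} points, {n_bytes / 2 ** 30:.2f} GiB")
# size=100: 1,000,000 points, 0.03 GiB
# size=512: 134,217,728 points, 4.00 GiB
```

A cube about 5 times wider per side holds about 134 times as many points.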

Moreover, ax.scatter will try to render all size**3 points without regard to the fact that most of those points are obscured by those on the outer shell.

It would help to reduce the number of points in volume before rendering it, perhaps by summarizing or resampling/interpolating it in some way.
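One possible way to summarize is block averaging: split the cube into non-overlapping f x f x f blocks and keep each block's mean. This is a minimal sketch, assuming every side length is divisible by the factor `f`; the name `block_mean` is hypothetical, and `scipy.ndimage.zoom` would be an interpolating alternative.

```python
import numpy as np


def block_mean(volume, f):
    """Downsample a 3D array by averaging non-overlapping f x f x f blocks.

    Assumes every dimension of `volume` is divisible by f.
    """
    nx, ny, nz = volume.shape
    if nx % f or ny % f or nz % f:
        raise ValueError("every dimension must be divisible by f")
    # Split each axis into (number of blocks, block length), then average
    # over the three block-length axes (1, 3 and 5).
    blocks = volume.reshape(nx // f, f, ny // f, f, nz // f, f)
    return blocks.mean(axis=(1, 3, 5))
```

For example, `block_mean(np.random.rand(128, 128, 128), 4)` returns a (32, 32, 32) array, i.e. 32,768 points instead of 2,097,152. Block index b covers original indices b*f to b*f + f - 1, so plotting at `b * f + (f - 1) / 2` places each point at its block's center in the original coordinates.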

We can also reduce the CPU and memory required from O(size**3) to O(size**2) by only plotting the outer shell:

import functools
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

def cartesian_product_broadcasted(*arrays):
    """
    http://stackoverflow.com/a/11146645/190597 (senderle)
    """
    broadcastable = np.ix_(*arrays)
    broadcasted = np.broadcast_arrays(*broadcastable)
    dtype = np.result_type(*arrays)
    rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
    out = np.empty(rows * cols, dtype=dtype)
    start, end = 0, rows
    for a in broadcasted:
        out[start:end] = a.reshape(-1)
        start, end = end, end + rows
    return out.reshape(cols, rows).T

# @profile  # used with `python -m memory_profiler script.py` to measure memory usage
def main():
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    size = 512
    volume = np.random.rand(size, size, size)
    x, y, z = cartesian_product_broadcasted(*[np.arange(size, dtype='int16')]*3).T
    mask = ((x == 0) | (x == size-1) 
            | (y == 0) | (y == size-1) 
            | (z == 0) | (z == size-1))
    x = x[mask]
    y = y[mask]
    z = z[mask]
    volume = volume.ravel()[mask]

    ax.scatter(x, y, z, c=volume, cmap=plt.get_cmap('Greys'))
    plt.show()

if __name__ == '__main__':
    main()
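The code above builds the full (size**3, 3) coordinate table before masking it. As a swapped-in variant (not the answer's code), the same outer shell can be selected by marking the six faces in a boolean array and letting `np.nonzero` return the coordinates. The mask still has size**3 entries, but at one byte each instead of three int16 coordinates per voxel. The name `outer_shell` is hypothetical.

```python
import numpy as np


def outer_shell(volume):
    """Return x, y, z and values for the voxels on the six faces of the cube."""
    mask = np.zeros(volume.shape, dtype=bool)
    mask[[0, -1], :, :] = True  # faces perpendicular to the first axis
    mask[:, [0, -1], :] = True  # faces perpendicular to the second axis
    mask[:, :, [0, -1]] = True  # faces perpendicular to the third axis
    # np.nonzero and boolean indexing both walk the array in C order,
    # so the coordinates and the values stay aligned.
    x, y, z = np.nonzero(mask)
    return x, y, z, volume[mask]
```

For size=512 this keeps 512**3 - 510**3 = 1,566,728 points, about 1.2% of the full cube; the results can be passed straight to `ax.scatter(x, y, z, c=values, cmap='Greys')`.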


But note that even when plotting only the outer shell, a plot with size=512 still needs around 1.3 GiB of memory. Also beware that if the machine runs short of RAM and the program starts using swap space, it will slow down dramatically, even if the total memory (RAM plus swap) is technically enough. If you find yourself in this situation, the only solutions are to find a smarter way to render an acceptable image using fewer points, or to buy more RAM.


unutbu
  • This code works fast enough for arrays of up to (100, 100, 100). Can I expect faster ways of building x, y, z to make it work for (512, 512, 512) arrays? – Dmitry Aug 30 '17 at 22:29
  • For an array of that size I think the main bottleneck is the rendering of 512**3 (around 134 million) points, not the creation of the coordinate arrays. I've edited the post above to show how you could plot only the outer shell of the cube. That reduces both the CPU and memory complexity from `O(size**3)` to `O(size**2)`. Still, depending on your machine's speed and resources, this could take a considerable amount of time to render. – unutbu Aug 31 '17 at 12:43
  • Most of the items in the array are zeros, so the result should show the shape inside the cube... – Dmitry Aug 31 '17 at 20:33

First, a dense grid of 512x512x512 points is far too much data to plot: not for technical reasons, but because you won't be able to see anything useful when looking at the plot. You probably need to extract some isosurfaces, look at slices, etc. If most of the points are invisible, that's probably okay, but then you should ask ax.scatter to show only the nonzero points, which also makes it faster.
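A minimal sketch of showing only the visible points, assuming "visible" means a value above some threshold (0 by default, i.e. the nonzero voxels) and a recent Matplotlib where `projection="3d"` needs no extra import. The `threshold` parameter and the names `visible_points` and `plot_visible` are hypothetical.

```python
import numpy as np


def visible_points(volume, threshold=0.0):
    """Coordinates and values of the voxels whose value exceeds `threshold`."""
    keep = volume > threshold
    x, y, z = np.nonzero(keep)      # C-order indices of the kept voxels
    return x, y, z, volume[keep]    # boolean indexing uses the same order


def plot_visible(volume, threshold=0.0):
    import matplotlib.pyplot as plt  # only needed for the plot itself

    x, y, z, values = visible_points(volume, threshold)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(x, y, z, c=values, cmap="gray")
    plt.show()


if __name__ == "__main__":
    # Mostly-zero test data: random values inside a ball, zeros elsewhere.
    n = 100
    i, j, k = np.indices((n, n, n))
    inside = (i - n / 2) ** 2 + (j - n / 2) ** 2 + (k - n / 2) ** 2 < (n / 4) ** 2
    plot_visible(np.where(inside, np.random.rand(n, n, n), 0.0))
```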

That said, here's how you can do it much more quickly. The trick is to eliminate all Python loops, including ones hidden inside libraries such as itertools.

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import numpy as np
import matplotlib.pyplot as plt

# Make this bigger to generate a dense grid.
N = 8

# Create some random data.
volume = np.random.rand(N, N, N)

# Create the x, y, and z coordinate arrays.  We use 
# numpy's broadcasting to do all the hard work for us.
# We could shorten this even more by using np.meshgrid.
x = np.arange(volume.shape[0])[:, None, None]
y = np.arange(volume.shape[1])[None, :, None]
z = np.arange(volume.shape[2])[None, None, :]
x, y, z = np.broadcast_arrays(x, y, z)

# Turn the volumetric data into an RGB array that's
# just grayscale.  There might be better ways to make
# ax.scatter happy.
c = np.tile(volume.ravel()[:, None], [1, 3])

# Do the plotting in a single call.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(),
           y.ravel(),
           z.ravel(),
           c=c)
plt.show()
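The code comment mentions `np.meshgrid`. One caveat: its default `indexing='xy'` swaps the first two axes, so `indexing='ij'` is needed for the flattened coordinates to line up with `volume.ravel()`. This sketch also passes the raw values with `cmap='gray'` instead of building an RGB array; that is an alternative, not what the answer does. The name `grid_coords` is hypothetical.

```python
import numpy as np


def grid_coords(shape):
    """Flattened x, y, z coordinates in the same (C) order as ndarray.ravel()."""
    axes = [np.arange(n) for n in shape]
    # indexing='ij' keeps axis 0 as x; the default 'xy' would swap x and y.
    x, y, z = np.meshgrid(*axes, indexing="ij")
    return x.ravel(), y.ravel(), z.ravel()


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    volume = np.random.rand(8, 8, 8)
    x, y, z = grid_coords(volume.shape)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    # Scalar values plus a colormap instead of an (N**3, 3) RGB array.
    ax.scatter(x, y, z, c=volume.ravel(), cmap="gray")
    plt.show()
```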
Mr Fooz
  • Thanks a lot for your answer! But it runs at the same speed as unutbu's solution: it only copes with arrays up to about (100, 100, 100). I'll try to clean up the data as much as I can so that only the nonzero points remain. – Dmitry Aug 31 '17 at 00:23

A similar solution can be achieved with product from itertools:

from itertools import product

import numpy as np
from matplotlib import pyplot as plt

N = 8
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")
space = np.array([*product(range(N), range(N), range(N))])  # all possible triplets of numbers from 0 to N-1
volume = np.random.rand(N, N, N)  # generate random data
# Color encodes position (RGB values in [0, 1)); marker size encodes the data value.
ax.scatter(space[:, 0], space[:, 1], space[:, 2], c=space / N, s=volume.ravel() * 300)
plt.show()
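`itertools.product` yields triples with the last index changing fastest, which is C order, so row i of `space` describes the same voxel as `volume.ravel()[i]`. The sketch below checks that and shows that `np.indices` builds the same table without a Python-level loop, which matters for large N because `product` creates N**3 tuples one at a time. The function names are hypothetical.

```python
from itertools import product

import numpy as np


def product_coords(n):
    """All (i, j, k) triples for an n x n x n grid, built with itertools.product."""
    return np.array(list(product(range(n), repeat=3)))


def indices_coords(n):
    """The same (n**3, 3) table built with np.indices, without a Python-level loop."""
    return np.indices((n, n, n)).reshape(3, -1).T
```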


Nejc Jezersek