
I wrote a function with this purpose:

  1. to create a matplotlib figure, but not display it
  2. with no frames, axes, etc.
  3. to plot in the figure an input 2D array using a user-passed colormap
  4. to save the colormapped 2D array from the canvas to a numpy array
  5. to make the output array the same size as the input

There are lots of questions with answers for tasks similar to either points 1-2 or point 4; for me it was also important to automate point 5. So I started by combining parts of @joe-kington's answer with @matehat's answer (and the comments on it), and with small modifications I arrived at this:

import numpy as np
import matplotlib.pyplot as plt

def mk_cmapped_data(data, mpl_cmap_name):

    # This is to define figure & output dimensions from input
    r, c = data.shape
    dpi = 72
    w = round(c/dpi, 2)
    h = round(r/dpi, 2)

    # This part modified from @matehat's SO answer:
    # https://stackoverflow.com/a/8218887/1034648
    # Pass dpi explicitly so that w*dpi and h*dpi are the input's pixel dimensions
    fig = plt.figure(frameon=False, dpi=dpi)
    fig.set_size_inches((w, h))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(data, aspect='auto', cmap=mpl_cmap_name, interpolation='none')
    fig.canvas.draw()

    # This part is to save the canvas to a numpy array
    # Adapted from Joe Kington's SO answer:
    # https://stackoverflow.com/a/7821917/1034648
    # buffer_rgba() replaces tostring_rgb(), which is deprecated in newer Matplotlib;
    # its array is already shaped (height, width, 4), so just drop the alpha channel
    mat = np.array(fig.canvas.buffer_rgba())[..., :3]
    mat = normalise(mat)  # normalise() is my own helper (not shown) to normalize the output range
    plt.close(fig)  # close this specific figure
    return mat

The function does what it is supposed to do and is fast enough. My question is whether I can make it more efficient and/or more Pythonic in any way.

MyCarta

2 Answers


If you want RGB output that exactly matches the shape of the input array, it's probably easiest not to create a figure at all, and instead use the colormap objects directly. For example:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Random data with a non 0-1 range.
data = 500 * np.random.random((100, 100)) - 200

# We'll use `Colormap` and `Normalize` instances directly
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(data.min(), data.max())

# The norm instance scales data to a 0-1 range, cmap maps it to RGBA (rows x cols x 4)
rgb = cmap(norm(data))

# MPL uses a 0-1 float RGBA representation, so we'll scale to 0-255
rgb = (255 * rgb).astype(np.uint8)

Image.fromarray(rgb).save('test.png')

Note that you likely don't want the additional step of saving it as a PNG, but I wanted to be able to show the result visually. This is exactly a 100x100 image where each pixel corresponds to the original input data.

[Image: 100x100 RGB colormapped version of the array]

This is what matplotlib does behind the scenes when you call imshow. The data is first run through a Normalize instance to scale it from its original range to 0-1. Then any Colormap instance can be called directly with the 0-1 results to turn the scalar data into RGBA data.
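As a minimal sketch, the same two steps can be wrapped into a drop-in replacement for mk_cmapped_data from the question. The function name `cmapped_array` is hypothetical (not part of Matplotlib). It assumes Matplotlib 3.5 or later for the `matplotlib.colormaps` registry, and it uses the colormap's `bytes=True` option, which returns uint8 RGBA values directly instead of 0-1 floats; the alpha channel is then dropped to get RGB.

```python
import numpy as np
import matplotlib
from matplotlib.colors import Normalize


def cmapped_array(data, cmap_name="viridis"):
    """Map a 2D array to a (rows, cols, 3) uint8 RGB array using a named colormap.

    Hypothetical helper: no figure is created, so there is exactly one
    output pixel per input element.
    """
    cmap = matplotlib.colormaps[cmap_name]               # Colormap instance (Matplotlib >= 3.5)
    norm = Normalize(vmin=data.min(), vmax=data.max())   # linear scaling to 0-1
    rgba = cmap(norm(data), bytes=True)                  # uint8 array, shape (rows, cols, 4)
    return rgba[..., :3]                                 # drop the alpha channel


# Random data with a non 0-1 range and a non-square shape
data = 500 * np.random.random((120, 80)) - 200
rgb = cmapped_array(data, "magma")
print(rgb.shape, rgb.dtype)  # (120, 80, 3) uint8
```

Because nothing is rendered, there is no dpi or figure-size bookkeeping, and the output shape follows directly from the input shape.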

Joe Kington

One-letter variable names are hard to understand.

Change:

r -> n_rows
c -> n_cols
w -> width
h -> height
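Applied to the dimension block at the top of mk_cmapped_data, the renaming might look like the sketch below. The helper name `figure_size_inches` is hypothetical and only serves to make the snippet self-contained; the 72 dpi default and the two-decimal rounding are kept from the question.

```python
import numpy as np


def figure_size_inches(data, dpi=72):
    """Return (width, height) in inches for a figure showing `data` at `dpi`."""
    n_rows, n_cols = data.shape
    width = round(n_cols / dpi, 2)
    height = round(n_rows / dpi, 2)
    return width, height


print(figure_size_inches(np.zeros((144, 72))))  # (1.0, 2.0)
```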
TheSaint321