80

In Python and Matplotlib, it is easy to either display the plot as a popup window or save the plot as a PNG file. How can I instead save the plot to a numpy array in RGB format?

Trenton McKinney
user1003146

8 Answers

106

This is a handy trick for unit tests and the like, when you need to do a pixel-to-pixel comparison with a saved plot.

One way is to use fig.canvas.tostring_rgb and then numpy.frombuffer with the appropriate dtype. There are other ways as well, but this is the one I tend to use.

E.g.

import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
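
As far as I can tell, recent Matplotlib releases deprecated `tostring_rgb` and later removed it, so the snippet above may fail on a current install. A minimal sketch of the same idea using `buffer_rgba` instead, assuming the Agg backend is acceptable (the figure size, dpi, and plotted data here are only illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: a non-interactive Agg canvas is fine here
import matplotlib.pyplot as plt
import numpy as np

# Illustrative figure: 4 x 3 inches at 100 dpi gives a 400 x 300 pixel canvas.
fig = plt.figure(figsize=(4, 3), dpi=100)
ax = fig.add_subplot(111)
ax.plot([0, 1, 2], [0, 1, 4])

# The figure has to be rendered before the buffer holds any pixels.
fig.canvas.draw()

# buffer_rgba() exposes the rendered pixels; np.asarray gives a
# (height, width, 4) uint8 view of that buffer.
rgba = np.asarray(fig.canvas.buffer_rgba())

# Drop the alpha channel and copy so the result owns its memory.
rgb = rgba[..., :3].copy()
```

No reshape is needed here because the buffer already carries its height and width.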
Martin Valgur
Joe Kington
  • Is this only supported on certain backends? It does not seem to work with the `macosx` backend (`tostring_rgb` not found). – mirosval Mar 05 '14 at 15:07
  • 6
    Works on Agg, add `matplotlib.use('agg')` before `import matplotlib.pyplot as plt` to use it. – mirosval Mar 05 '14 at 15:38
  • 11
    With images, the canvas adds a big margin, so I found it useful to insert `fig.tight_layout(pad=0)` before drawing. – Dan Allan Oct 14 '14 at 16:01
  • 2
    For figures with lines and text, it can also be important to turn antialiasing off. For lines ```plt.setp([ax.get_xticklines() + ax.get_yticklines() + ax.get_xgridlines() + ax.get_ygridlines()],antialiased=False)``` and for text ```mpl.rcParams['text.antialiased']=False``` – kmader Nov 03 '16 at 11:13
  • Is it possible to store the figure in the same shape as the input dataset? – J.Down May 13 '17 at 12:49
  • There will be padding around the figure. There are various answers for turning it off, but you can also just crop the white out of the resulting array: `data = data[:, ~np.all(data == 255, axis=(0, 2))]` (cut all-white columns) and `data = data[~np.all(data == 255, axis=(1, 2)), :]` (cut all-white rows). – JASON G PETERSON Feb 24 '18 at 17:30
  • 15
    @JoeKington `np.fromstring` with `sep=''` is deprecated since version 1.14. It should be replaced with `data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)` in future versions – OriolAbril Jun 08 '18 at 16:32
  • 3
    In case you run into `'FigureCanvasGTKAgg' object has no attribute 'renderer'`, remember to `matplotlib.use('Agg')`: https://stackoverflow.com/a/35407794/5339857 – Roy Shilkrot Feb 20 '19 at 23:47
  • 1
    Instead of `matplotlib.use('agg')` you can do `plt.switch_backend('agg')`; with this approach you won't have to run .use before `import matplotlib.pyplot as plt` – loknar Mar 28 '20 at 19:39
  • I accidentally unvoted my upvote to @loknar 's comment, but that is what I used in the end – MyCarta May 26 '20 at 15:47
  • Is print_to_buffer better than tostring_rgb? https://matplotlib.org/3.1.1/gallery/user_interfaces/canvasagg.html – CircaLucid Dec 16 '20 at 16:22
  • This does not work data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) Wrong shape – Marat Zakirov Jul 28 '21 at 09:32
  • @MaratZakirov see my answer for a solution. – Epimetheus Jan 18 '22 at 08:48
  • For some reason, without the `fig.canvas.draw()` there is a size mismatch; this was also mentioned by @Fabian Hertwig on Jonan Gueorguiev's answer – Nir May 30 '22 at 17:02
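
Several comments above are about the white margin around the plot. A minimal sketch of cropping it from the array after the fact, assuming an (H, W, C) uint8 image with a pure white background; `crop_white_border` is a hypothetical helper, not part of Matplotlib or NumPy. Unlike dropping every white row and column, it only trims the outer border, so white gaps inside the plot are kept:

```python
import numpy as np

def crop_white_border(img, white=255):
    """Trim outer rows and columns that are entirely `white` from an (H, W, C) array."""
    # A pixel counts as background only if every channel equals `white`.
    is_white = np.all(img == white, axis=-1)
    rows = np.where(~is_white.all(axis=1))[0]  # rows with some non-white pixel
    cols = np.where(~is_white.all(axis=0))[0]  # columns with some non-white pixel
    if rows.size == 0 or cols.size == 0:
        return img  # nothing but background, leave it untouched
    return img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

# Small synthetic check: a black 2 x 3 patch inside a white 6 x 8 image.
demo = np.full((6, 8, 3), 255, dtype=np.uint8)
demo[2:4, 3:6] = 0
cropped = crop_white_border(demo)
```

Calling `fig.tight_layout(pad=0)` before drawing, as suggested above, reduces the margin at the source instead.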
49

There is a somewhat simpler alternative to @JUN_NETWORKS's answer. Instead of saving the figure as PNG, one can use another format, such as raw or rgba, and skip the cv2 decoding step.

In other words the actual plot-to-numpy conversion boils down to:

import io
import numpy as np

# DPI must equal fig.dpi; otherwise the buffer size will not match fig.bbox
io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     (int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()

Hope this helps.

Jonan Gueorguiev
  • 1
    I think this answer is far superior to the ones above: 1) It produces high-res images and 2) doesn't rely on external packages like cv2. – jrieke Nov 03 '20 at 23:54
  • 9
    I get a reshape error "cannot reshape array of size 3981312 into shape (480,640,newaxis)". Any ideas? – Fabian Hertwig Feb 18 '21 at 18:39
  • Indeed this answer is exactly what I was looking for ! Thank you ! – milembar Mar 07 '21 at 09:05
  • @FabianHertwig - make sure that not only the number of pixels, but also the number of (color) channels match. – Jonan Gueorguiev Mar 21 '21 at 08:56
  • 2
    @FabianHertwig I have the same problem, and here is the fix. When you create the fig you have to set dpi to be the same when you saved it `fig = plt.figure(figsize=(16, 4), dpi=128)` then `fig.savefig(io_buf, format='raw', dpi=128)` – Zen3515 May 18 '21 at 08:32
  • How can I read this as 3 channels, i.e. RGB, something like (128, 128, 3)? Right now it comes out as (128, 128, 4). – Manish Jun 15 '21 at 11:49
  • 3
    This worked for me when omitting the `dpi` parameter, i.e. `fig.savefig(io_buf, format='raw')` – Mr.Epic Fail Jul 30 '21 at 10:09
  • For anyone trying this method: it is very problematic (saving the image as 'raw' to the io_buffer yields a shape that makes no sense). The top answer works far better – Hersh Joshi Aug 02 '22 at 23:11
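
The reshape errors and the 4-channel result reported in the comments above can be sketched in one place. This is a minimal example, assuming the Agg backend; the figure size and plotted data are only illustrative. It saves at the figure's own dpi so the byte count matches the canvas size, then drops the alpha channel of the RGBA output:

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: a non-interactive Agg canvas is fine here
import matplotlib.pyplot as plt
import numpy as np

# Illustrative figure: 4 x 3 inches at 100 dpi gives 400 x 300 pixels.
fig = plt.figure(figsize=(4, 3), dpi=100)
fig.add_subplot(111).plot([0, 1, 2], [0, 1, 4])

with io.BytesIO() as io_buf:
    # Saving at fig.dpi keeps the buffer size equal to height * width * 4.
    fig.savefig(io_buf, format="raw", dpi=fig.dpi)
    flat = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)

width, height = fig.canvas.get_width_height()
rgba = flat.reshape(height, width, 4)  # 'raw' output is RGBA
rgb = rgba[..., :3]                    # keep R, G, B and drop alpha
```

With a different `dpi` passed to `savefig`, the buffer holds more or fewer pixels than `height * width`, which is consistent with the "cannot reshape array" error quoted above.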
23

Some people propose a method like this:

np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

Of course, this code works, but the resulting numpy array image has a low resolution.

My proposed code is this:

import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")


# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

This code works well.
You can get a high-resolution image as a numpy array by passing a large value for the dpi argument.
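
To make the effect of `dpi` concrete, here is a minimal sketch of the same PNG round trip that decodes with Pillow instead of cv2. This is a swapped-in decoder, not the code above; it assumes Pillow is installed (Matplotlib normally pulls it in), and `fig_to_rgb_array` is a hypothetical helper name. The array size in pixels is the figure size in inches times the dpi:

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: a non-interactive Agg canvas is fine here
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def fig_to_rgb_array(fig, dpi=180):
    """Render `fig` to an in-memory PNG at `dpi` and decode it to (H, W, 3) uint8."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        # convert("RGB") decodes the PNG while the buffer is open and drops alpha.
        return np.asarray(Image.open(buf).convert("RGB"))

# Illustrative 4 x 3 inch figure.
fig = plt.figure(figsize=(4, 3))
fig.add_subplot(111).plot([0, 1], [0, 1])

low = fig_to_rgb_array(fig, dpi=100)   # 400 x 300 pixels
high = fig_to_rgb_array(fig, dpi=200)  # 800 x 600 pixels
```

Pillow already returns RGB channel order, so the BGR-to-RGB conversion step is not needed in this variant.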

JUN_NETWORKS
17

Time to benchmark your solutions.

import io
import matplotlib
matplotlib.use('agg')  # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))


def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))


def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)


def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In this scenario, saving to an in-memory buffer in raw format is the fastest way to convert a matplotlib figure to a numpy array.

Additional remarks:

  • if you don't have access to the figure, you can always get it from the axes:

    fig = ax.figure

  • if you need the array in channel x height x width format, do:

    im = im.transpose((2, 0, 1))

dizcza
6

MoviePy makes converting a figure to a numpy array quite simple. It has a built-in function for this called mplfig_to_npimage(). You can use it like this:

from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt

fig = plt.figure()  # make a figure
numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array
Daniel Giger
5

In case somebody wants a plug-and-play solution that does not require modifying any prior code (getting a reference to the pyplot figure and so on), the code below worked for me. Just add it after all pyplot statements, i.e. just before pyplot.show().

import numpy
from matplotlib import pyplot

canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))
Nagabhushan S N
2

As Joe Kington has pointed out, one way is to draw on the canvas, convert the canvas to a byte string and then reshape it into the correct shape.

import matplotlib.pyplot as plt
import numpy as np
import math

plt.switch_backend('Agg')


def canvas2rgb_array(canvas):
    """Adapted from: https://stackoverflow.com/a/21940031/959926"""
    canvas.draw()
    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    ncols, nrows = canvas.get_width_height()
    scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
    return buf.reshape(scale * nrows, scale * ncols, 3)


# Make a simple plot to test with
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)
fig, ax = plt.subplots()
ax.plot(t, s)

# Extract the plot as an array
plt_array = canvas2rgb_array(fig.canvas)
print(plt_array.shape)

However, since canvas.get_width_height() returns the width and height in display coordinates, there are sometimes scaling issues; the code in this answer resolves them by inferring the scale factor from the buffer size.

Epimetheus
0

Cleaned up version of the answer by Jonan Gueorguiev:

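import io
import numpy as np

# fig is an existing figure; dpi should equal fig.dpi so the reshape matches
with io.BytesIO() as io_buf:
  fig.savefig(io_buf, format='raw', dpi=dpi)
  image = np.frombuffer(io_buf.getvalue(), np.uint8).reshape(
      int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)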
Hugues