I have a plotting function from a library that takes an array and generates a heatmap from it (I'll use plt.imshow here for the sake of the MWE). The function does not return anything: it just shows the figure (fig.show() in the MWE below):

import matplotlib.pyplot as plt
import numpy as np

# Complicated function from a library which I technically could but should not modify
# simplified for MWE
def heatmap(arr):
    fig, ax = plt.subplots()
    _ = ax.imshow(arr)
    fig.show()

If I call this function in a loop, I get multiple figures.

for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)

I want to collect these figures and animate them at the end, like:

plots = []
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    plots.append(plt.gca())  # what should this actually look like?

# wish this existed
plt.animate(plots) # ???

I do have access to the code for heatmap, so I could technically change it to return the figure and axes, but I would like to find a simple solution that would work even if I had no access to the plotting code.

Is this possible with matplotlib? All examples I see in the docs suggest I have to update the figure, and not collect many different ones.

Daniel
  • Just to be crystal clear: the complicated function returns a figure, not the data? – Paul Brodersen Jun 17 '22 at 15:41
  • If you do have access to the data: [relevant matplotlib example](https://matplotlib.org/stable/gallery/animation/dynamic_image.html). – Paul Brodersen Jun 17 '22 at 15:41
  • As is, it doesn't return anything. Great example, thanks. It does assume the function returns an ax object and doesn't create new figures. It would be a starting point. Any ideas on how to work around the missing return value? I will try to iterate on that example. – Daniel Jun 17 '22 at 15:42
  • You could save out the images and assemble them into a gif or similar: https://stackoverflow.com/a/35943809/2912349 This probably only works well if the axis limits & labels remain constant. – Paul Brodersen Jun 17 '22 at 15:44
  • I researched this a bit [1](https://stackoverflow.com/q/42072704) [2](https://stackoverflow.com/q/45810557) [3](https://stackoverflow.com/q/15962849) and it looks hard. (The last one is similar to this question but has no solution. While axes are pickle-able, unpickling one appears to move it to a new figure instead of onto an existing one.) There's always the option of patching plt.subplot() to plot on the existing figure, but that's a bit ugly I guess. – user202729 Jun 18 '22 at 13:22
  • Also https://stackoverflow.com/q/6309472/5267751. – user202729 Jun 18 '22 at 13:25
  • @PaulBrodersen thank you, your suggestions guided the way to a working solution. – Daniel Jun 19 '22 at 22:55

1 Answer

Based on the comments, I found a working solution that collects the plots generated in a loop, without having to touch the plotting function, and saves them as an animation.

The original loop I was using was the following:

for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)

I'll first give the solution, and then a step-by-step explanation of the logic.

Final Solution

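import matplotlib.animation as animation  # needed for ArtistAnimation below

plots = []
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    if i==0:
        # the first figure becomes the canvas for the whole animation
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()        # detach ax from dummy_fig
        ax.figure = fig    # re-parent it to the canvas figure
        fig.add_axes(ax)
        plt.close(dummy_fig)

    plots.append([ax])  # one list of artists per frame

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")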
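
As far as I know, writing an .mp4 with ani.save relies on ffmpeg being installed. The following is only a minimal sketch, not part of the solution above, of how one might fall back to a GIF when ffmpeg is missing. The helper name save_animation, its prefer_mp4 parameter, the demo animation, and the Agg backend are all assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run, so no windows are opened
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np


def save_animation(ani, stem, prefer_mp4=True):
    """Hypothetical helper: save as <stem>.mp4 if ffmpeg is usable, else <stem>.gif."""
    if prefer_mp4 and animation.writers.is_available("ffmpeg"):
        path = f"{stem}.mp4"
        ani.save(path, writer="ffmpeg")
    else:
        path = f"{stem}.gif"
        # 20 fps matches the 50 ms frame interval used for the animation
        ani.save(path, writer=animation.PillowWriter(fps=20))
    return path


# Tiny demo animation so the helper has something to save
# (three random images drawn into one axes, one artist list per frame).
fig, ax = plt.subplots()
frames = [[ax.imshow(np.random.rand(10, 10), animated=True)] for _ in range(3)]
ani = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=200)
```

Calling save_animation(ani, "video") would then write video.mp4 or video.gif depending on what is available on the machine.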

Step-by-step explanation

To collect the plots and animate them later, I had to make the following modifications:

  1. Get a handle to the figure and axes that heatmap just created:
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    fig, ax = plt.gcf(), plt.gca()  # add this
  2. Use the very first figure as the drawing canvas for all later axes:
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    if i==0:  # fig is the one we'll use for our animation canvas.
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()  # we will ignore dummy_fig
        plt.close(dummy_fig)
  3. Before closing the other figures, move their axes to our main canvas:
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    if i==0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()  # remove ax from dummy_fig
        ax.figure = fig  # now assign it to our canvas fig
        fig.add_axes(ax)  # also patch the fig axes to know about it
        plt.close(dummy_fig)
  4. Set the axes to be animated (this doesn't seem to be strictly necessary, though):
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    if i==0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)  # from plt example, but doesn't seem needed
        # we could however add info to each plot here, e.g.
        # ax.set(xlabel=f"image {i}")  # this could be done in i ==0 cond. too.
        ax.remove()
        ax.figure = fig 
        fig.add_axes(ax)
        plt.close(dummy_fig)
  5. Now simply collect all of these axes in a list and animate them (this needs import matplotlib.animation as animation):
plots = []
for i in range(100):
    arr = np.random.rand(10,10)
    heatmap(arr)
    if i==0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
        
    plots.append([ax])

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
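
For comparison, the image-based route suggested in the comments (save each figure as an image and assemble the images into a GIF) can be sketched without moving axes between figures. This is only a rough sketch under stated assumptions: the helper names collect_frames and save_gif are made up for illustration, Pillow is an extra dependency not used above, the Agg backend is chosen so nothing is displayed, and heatmap is a stand-in for the library function with the show call left out.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen only
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def heatmap(arr):
    # Stand-in for the library function; it returns nothing.
    fig, ax = plt.subplots()
    ax.imshow(arr)


def collect_frames(plot_func, arrays):
    """Hypothetical helper: call plot_func per array and keep each figure as an image."""
    frames = []
    for arr in arrays:
        plot_func(arr)
        fig = plt.gcf()          # the figure plot_func just created
        fig.canvas.draw()        # render it into the Agg buffer
        rgba = np.asarray(fig.canvas.buffer_rgba())
        frames.append(Image.fromarray(rgba).convert("RGB"))  # convert() copies the pixels
        plt.close(fig)           # avoid piling up open figures
    return frames


def save_gif(frames, path, ms_per_frame=50):
    """Hypothetical helper: write the collected images as a looping GIF."""
    frames[0].save(path, save_all=True, append_images=frames[1:],
                   duration=ms_per_frame, loop=0)
```

Something like save_gif(collect_frames(heatmap, [np.random.rand(10, 10) for _ in range(100)]), "heatmaps.gif") would then produce the animation. As noted in the comments, this probably only looks good if the axis limits and labels stay constant between frames.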
user202729
Daniel