Based on the comments, I found a working solution for collecting plots generated in a loop, without needing access to the plotting function itself, and saving them to an animation.
The original loop I was using was the following:
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
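Here heatmap is treated as a black box. To make the snippets in this answer runnable on their own, below is a hypothetical stand-in for it. It assumes, as the rest of this answer implies, that every call opens a new figure and returns nothing:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    """Hypothetical stand-in for the black-box plotting function.

    Assumed behavior: every call opens a brand-new figure, draws into it,
    and returns nothing, so the caller gets no handle to the figure.
    """
    _, ax = plt.subplots()
    ax.imshow(arr, cmap="viridis")


for i in range(3):
    arr = np.random.rand(10, 10)
    heatmap(arr)

print(len(plt.get_fignums()))  # 3: one open figure per call
```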
I'll first give the solution, and then a step-by-step explanation of the logic.
Final Solution
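import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# heatmap(arr) is the black-box plotting function; it opens a new figure on every call.
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
    plots.append([ax])
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")  # MP4 output typically requires FFmpeg to be installed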
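Saving to .mp4 goes through an external encoder (FFmpeg by default), so it can fail on machines without it. A minimal sketch that checks Matplotlib's writer registry first and falls back to the Pillow-based GIF writer; the toy frames, file names and fps value are placeholders, not part of the original solution:

```python
import os

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Toy frames on one figure, standing in for the collected axes.
fig, ax = plt.subplots()
plots = [[ax.imshow(np.random.rand(10, 10), animated=True)] for _ in range(5)]
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)

if animation.writers.is_available("ffmpeg"):
    out = "video.mp4"
    ani.save(out, writer="ffmpeg", fps=20)
else:
    # PillowWriter only needs Pillow, which Matplotlib already depends on.
    out = "video.gif"
    ani.save(out, writer=animation.PillowWriter(fps=20))

print("saved", out, os.path.getsize(out), "bytes")
```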
Step-by-step explanation
To save the plots and animate them later, I made the following modifications:
- get a handle to the figure and axes that each heatmap call creates:
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    fig, ax = plt.gcf(), plt.gca()  # add this
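This works because pyplot keeps track of the most recently created figure and its Axes as "current". A small check, using a hypothetical stand-in for heatmap that is assumed to open a new figure per call:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    # Hypothetical black box: opens a new figure on every call.
    _, ax = plt.subplots()
    ax.imshow(arr)


handles = []
for i in range(3):
    heatmap(np.random.rand(10, 10))
    # pyplot makes the newest figure (and its Axes) "current", so
    # gcf()/gca() recover the handles the black box never returned.
    fig, ax = plt.gcf(), plt.gca()
    handles.append((fig, ax))

print(len({f.number for f, _ in handles}))  # 3 distinct figures
print(all(ax.figure is f for f, ax in handles))  # True
```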
- use the very first figure as the drawing canvas for all subsequent axes:
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:  # fig is the one we'll use as our animation canvas.
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()  # we will ignore dummy_fig
        plt.close(dummy_fig)
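Closing each extra figure keeps only the canvas open. Without the close, all 100 figures would stay in memory, and pyplot warns once more than rcParams["figure.max_open_warning"] (20 by default) are open. A sketch of this step with the same hypothetical stand-in for heatmap:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    # Hypothetical black box: opens a new figure on every call.
    _, ax = plt.subplots()
    ax.imshow(arr)


for i in range(100):
    heatmap(np.random.rand(10, 10))
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()  # keep the first figure as the canvas
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        plt.close(dummy_fig)  # discard the extra figure right away

# Only the canvas figure is still open at this point.
print(plt.get_fignums())
```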
- before closing the other figures, move their axes to our main canvas:
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()  # remove ax from dummy_fig
        ax.figure = fig  # now assign it to our canvas fig
        fig.add_axes(ax)  # also register ax in fig's list of axes
        plt.close(dummy_fig)
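For context: the Matplotlib documentation for Figure.add_axes only accepts an existing Axes instance if it was created in that same figure, which is why ax.figure is reassigned before the call. Since reassigning ax.figure is not part of the documented API, this step may be sensitive to the Matplotlib version. A more conservative alternative, a different technique from the one used in this answer, is to render each black-box figure to an RGBA pixel array, close it, and animate those images on a single canvas. A sketch, assuming the Agg backend and the same hypothetical stand-in for heatmap:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np


def heatmap(arr):
    # Hypothetical black box: opens a new figure on every call.
    _, ax = plt.subplots()
    ax.imshow(arr)


frames_rgba = []
for i in range(10):
    heatmap(np.random.rand(10, 10))
    src = plt.gcf()
    src.canvas.draw()  # render the black-box figure off-screen
    # Copy the rendered pixels (H x W x 4, uint8) before closing the figure.
    frames_rgba.append(np.asarray(src.canvas.buffer_rgba()).copy())
    plt.close(src)

# One canvas figure whose pixel size matches the captured frames.
h, w = frames_rgba[0].shape[:2]
fig = plt.figure(figsize=(w / 100, h / 100), dpi=100)
ax = fig.add_axes([0, 0, 1, 1])  # let the image fill the whole figure
ax.axis("off")

# ArtistAnimation wants one list of artists per frame.
plots = [[ax.imshow(img, animated=True)] for img in frames_rgba]
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
```

The resulting ani can be saved with ani.save(...) exactly as in the answer. The trade-off is that frames become fixed bitmaps rather than live Axes.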
- set the axes to be animated (this doesn't seem to be strictly necessary, though):
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)  # from the Matplotlib example, but doesn't seem needed
        # we could, however, add info to each plot here, e.g.
        # ax.set(xlabel=f"image {i}")  # this could be done in the i == 0 branch too
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
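According to the Matplotlib docs, an artist with animated=True is excluded from the regular figure draw and is meant to be drawn explicitly, which is how blitting works. In recent Matplotlib source, ArtistAnimation also appears to set this flag on every frame's artists to match its own blit argument when it initializes; if so, that would explain why setting it by hand makes no visible difference here. A small check of the exclusion behavior, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.imshow(np.random.rand(10, 10))

fig.canvas.draw()
normal = np.asarray(fig.canvas.buffer_rgba()).copy()

ax.set(animated=True)  # same as ax.set_animated(True)
fig.canvas.draw()
skipped = np.asarray(fig.canvas.buffer_rgba()).copy()

# With animated=True the Axes (and its image) is left out of the regular
# draw, so the second render no longer contains the heatmap.
print((normal != skipped).any())  # True
```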
- finally, collect all of these axes in a list and animate them:
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
    plots.append([ax])
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
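ArtistAnimation takes a list of frames, where each frame is itself a list of artists shown together while the other frames' artists are hidden; that is why each Axes is wrapped as [ax]. A frame can hold more than one artist. A minimal sketch using the conventional single-Axes approach (a different technique from the one above, usable when you do control the plotting calls), pairing each image with a per-frame caption; the caption text is illustrative:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig, ax = plt.subplots()
plots = []
for i in range(5):
    img = ax.imshow(np.random.rand(10, 10), animated=True)
    # A Text artist acting as a per-frame caption above the Axes.
    label = ax.text(0.5, 1.02, f"image {i}", transform=ax.transAxes,
                    ha="center", va="bottom")
    plots.append([img, label])  # everything visible in frame i

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
```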