I have a function, modified from an external source, that takes a list of images and displays them on a grid within a Jupyter notebook (Python 3).

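import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def show_images(images):
    # Flatten each image into one row: (n_images, 784)
    images = np.reshape(images, [images.shape[0], -1])

    # A brand-new figure is created on every call
    fig = plt.figure(figsize=(28, 28))
    gs = gridspec.GridSpec(28, 28)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_aspect('equal')
        plt.imshow(img.reshape([28, 28]))
    return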

This takes a number of MNIST or similarly formatted data arrays (hence the tell-tale 28-pixel image size) and is invoked every so often deep inside a machine learning algorithm. Every time it is invoked, it adds another figure to the Jupyter output, so by the end I have a long, scrolling output.

What I would like is to modify this function (or create a companion one) so that it updates the existing figure in place rather than making new figures. Ideally it would look like a slowly evolving animation, with the frame rate set by the speed of the loop (which is relatively slow, on the order of seconds).

I've tried various ways of saving and reusing the figure, the gridspec, and the axes, combined with various permutations of drawing and re-showing these elements, but nothing works: either I get no graphical output at all, or I get a single static image that does not update (though it also does not keep adding more figures).

Novak

1 Answer

A possible solution would be to return gs from show_images and pass it back in on later calls. The function could be modified so that it creates a new figure only when gs is not passed, along the lines of:

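import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def show_images(images, gs=None):
    # Flatten each image into one row: (n_images, 784)
    images = np.reshape(images, [images.shape[0], -1])

    if gs is None:
        # First call: create the figure and the grid layout
        fig = plt.figure(figsize=(28, 28))
        gs = gridspec.GridSpec(28, 28)
        gs.update(wspace=0.05, hspace=0.05)
    else:
        plt.clf()  # clear the current figure if one is already present

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_aspect('equal')
        plt.imshow(img.reshape(28, 28))

    # Return the grid so the caller can pass it back in next time
    return gs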

Then, when calling the function:

# First call: don't pass gs, so a new figure is created
gs = show_images(images)

# Many lines of code later...

# Later calls: pass gs back in so the current figure is cleared and redrawn
gs = show_images(images, gs)
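
A related variation, offered only as a rough sketch rather than a tested notebook recipe, is to build the figure and its image artists once and then swap new pixel data into them with set_data, so no figure is cleared or recreated. The helper names make_image_grid and update_image_grid are hypothetical, the near-square grid layout is a choice made for the example, and the fixed colour range assumes pixel values scaled to [0, 1]:

```python
import numpy as np
import matplotlib.pyplot as plt


def make_image_grid(n_images, side=28):
    """Create the figure once; return it with one AxesImage per image slot."""
    # Hypothetical layout choice: a near-square grid just big enough for n_images
    ncols = int(np.ceil(np.sqrt(n_images)))
    nrows = int(np.ceil(n_images / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)

    for ax in axes.ravel():
        ax.axis('off')

    # Assumption: pixel values lie in [0, 1], so the colour scale is fixed up front
    blank = np.zeros((side, side))
    artists = [ax.imshow(blank, vmin=0.0, vmax=1.0)
               for ax in axes.ravel()[:n_images]]
    return fig, artists


def update_image_grid(fig, artists, images, side=28):
    """Replace the pixel data of the existing artists; no new figure is made."""
    images = np.reshape(images, [len(images), -1])
    for artist, img in zip(artists, images):
        artist.set_data(img.reshape(side, side))
    fig.canvas.draw_idle()  # request a redraw of the existing canvas
```

With an interactive backend the redraw should appear in the existing figure. With the inline backend the updated figure still has to be pushed to the output cell; one option is IPython.display.display(fig, display_id=True), which returns a handle whose update(fig) method replaces the earlier output.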
DavidG