
Background

Using a CNN autoencoder, I look at a 2D projection of the latent space of an image dataset. I'd like to hover over the 2D scatter plot and display the corresponding image. I also have the images' true labels and would like to use them to color the scatter points and show them in a legend.

Setup

My original images are stored in a 3D array X_plot, the PCA-reduced dataset is in X, and the labels of the images are in y.

X_plot.shape  = (n, 64, 64)  # n images of 64x64
X.shape       = (n, 2)       # 2D coordinates of each image
y.shape       = (n, )        # n labels
# Example code to reproduce
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np

n = 20
num_classes = 4
X_plot = np.random.rand(n, 64, 64)
X = np.random.rand(n, 2)
y = np.random.randint(num_classes, size=n)
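The question does not show how X is obtained from the autoencoder. As a minimal sketch, assuming the latent codes are stored in a hypothetical array Z of shape (n, d), a 2D PCA projection can be computed with plain numpy SVD (scikit-learn's PCA would be an equally valid choice; the sign of each component is arbitrary either way):

```python
import numpy as np

def pca_2d(Z):
    """Project the rows of Z (shape (n, d)) onto their first two principal components."""
    Z = np.asarray(Z, dtype=float)
    Zc = Z - Z.mean(axis=0)  # center every latent dimension
    # The rows of Vt are the principal directions, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(Zc, full_matrices=False)
    return Zc @ Vt[:2].T  # shape (n, 2)

# Hypothetical latent codes: 20 images encoded into a 16-dimensional bottleneck
rng = np.random.default_rng(0)
Z = rng.normal(size=(20, 16))
X = pca_2d(Z)
print(X.shape)  # (20, 2)
```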

Current code

Scatter with image display on hovering

This is largely inspired by this answer on Stack Overflow.

# Split 2D coordinates into list of xs and ys
xx, yy = zip(*X)
# create figure and plot scatter
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")

# create the annotations box
im = OffsetImage(X_plot[0,:,:], zoom=1, cmap='gray')
xybox=(50., 50.)
ab = AnnotationBbox(im, (0,0), xybox=xybox, xycoords='data',
        boxcoords="offset points",  pad=0.3,  arrowprops=dict(arrowstyle="->"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)

def hover(event):
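    # check whether the mouse is over one of the scatter points
    cont, details = line.contains(event)
    if cont:
        # index of the hovered point; take the first one, because
        # `ind, = ...` raises a ValueError when several points are under the cursor
        ind = details["ind"][0]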
        # get the figure size in pixels
        w, h = fig.get_size_inches()*fig.dpi
        # ws/hs are -1 if the mouse is in the right/top half of the figure, else 1
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        # in those halves, flip the annotation box to the other side of the
        # cursor so that it stays inside the figure
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        # make annotation box visible
        ab.set_visible(True)
        # place it at the position of the hovered scatter point
        ab.xy =(xx[ind], yy[ind])
        # set the image corresponding to that point
        im.set_data(X_plot[ind,:,:])
    else:
        #if the mouse is not over a scatter point
        ab.set_visible(False)
    fig.canvas.draw_idle()

# add callback for mouse moves
fig.canvas.mpl_connect('motion_notify_event', hover)           
plt.show()
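The ws/hs arithmetic in hover() evaluates to either -1 or 1. Written as a plain function (box_offset is a hypothetical helper, not part of matplotlib), the rule is easier to read. Coordinates are display pixels with the origin at the bottom-left of the figure, as for event.x and event.y:

```python
def box_offset(x, y, width, height, xybox=(50.0, 50.0)):
    """Offset (in points) of the annotation box relative to the hovered point.

    Mirrors the ws/hs logic of hover(): the box goes to the left of the cursor
    when it is in the right half of the figure, and below it when it is in the
    top half, so that the image stays inside the canvas.
    """
    ws = -1 if x > width / 2.0 else 1
    hs = -1 if y > height / 2.0 else 1
    return (xybox[0] * ws, xybox[1] * hs)

# A 640x480 pixel canvas (matplotlib's default 6.4x4.8 in figure at 100 dpi)
print(box_offset(100, 100, 640, 480))  # (50.0, 50.0): lower-left, box up and right
print(box_offset(600, 400, 640, 480))  # (-50.0, -50.0): upper-right, box down and left
```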

Scatter with legend

If I want to display the 2D scatter with points colored and labeled with y, I use the following code:

fig = plt.figure()
ax = fig.add_subplot(111)

labels = np.unique(y)
for label in labels:
    filtered_by_label = X[y == label]
    ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)

ax.legend()
ax.axis('off')
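A variant of this snippet, sketched here as an alternative rather than the author's code, draws every point with a single scatter call colored by y and builds the legend with PathCollection.legend_elements. Because there is only one artist, the hover callback from the first snippet could test sc.contains(event) instead of line.contains(event): collections also report the hit indices under the "ind" key, and those indices refer directly to rows of X. The sketch uses stand-in data and assumes integer labels between 0 and 9, so that with cmap="tab10", vmin=0 and vmax=9 the colors match matplotlib's default color cycle used by the label loop:

```python
from matplotlib import pyplot as plt
import numpy as np

# Stand-in data with the shapes from the question (assumption: integer labels 0..9)
n, num_classes = 20, 4
rng = np.random.default_rng(0)
X = rng.random((n, 2))
y = np.arange(n) % num_classes  # every class present at least once

fig, ax = plt.subplots()
# One collection for all points; the labels are mapped through the colormap
sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="tab10", vmin=0, vmax=9,
                s=12, marker=".", alpha=0.9)
# One legend handle per distinct value of y, in sorted order
handles, _ = sc.legend_elements(prop="colors")
ax.legend(handles, [str(label) for label in np.unique(y)], title="label")
ax.axis("off")
plt.show()
```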

Challenge

I can't manage to merge the two pieces of code above. ax.plot doesn't seem to accept a list of labels for the legend. If I use the label loop from the second snippet instead, I get one scatter object per label, whereas the hover function tests a single line object; I tried merging several of them into one, without success.
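One way to keep one scatter per label and still hover is sketched below, under stated assumptions. Each collection remembers which rows of X it draws (np.flatnonzero(y == label)), and the callback asks each collection in turn whether it contains the event. PathCollection.contains returns indices local to that collection, so they are mapped back to rows of X before looking up the image. The data are stand-ins with the question's shapes (x positions spread out so points do not overlap), find_hovered is a hypothetical helper name, and the quadrant-based flipping of the box is left out for brevity:

```python
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np

# Stand-in data (assumption): n images of 64x64, 2D coordinates, integer labels
n, num_classes = 20, 4
rng = np.random.default_rng(0)
X_plot = rng.random((n, 64, 64))
X = np.column_stack([np.linspace(0, 1, n), rng.random(n)])
y = np.arange(n) % num_classes

fig, ax = plt.subplots()
scatters = []  # (collection, rows of X drawn by that collection), one per label
for label in np.unique(y):
    idx = np.flatnonzero(y == label)
    sc = ax.scatter(X[idx, 0], X[idx, 1], s=12, marker=".", alpha=0.9, label=str(label))
    scatters.append((sc, idx))
ax.legend()

# Hidden annotation box that will show the hovered image
im = OffsetImage(X_plot[0], zoom=1, cmap="gray")
ab = AnnotationBbox(im, (0, 0), xybox=(50.0, 50.0), xycoords="data",
                    boxcoords="offset points", pad=0.3,
                    arrowprops=dict(arrowstyle="->"))
ax.add_artist(ab)
ab.set_visible(False)

def find_hovered(event):
    """Return the row of X under the mouse, or None (hypothetical helper)."""
    for sc, idx in scatters:
        hit, details = sc.contains(event)
        if hit:
            # details["ind"] indexes this collection only; map it back to X
            return idx[details["ind"][0]]
    return None

def hover(event):
    k = find_hovered(event)
    if k is None:
        ab.set_visible(False)
    else:
        ab.xy = tuple(X[k])      # anchor the box on the hovered point
        im.set_data(X_plot[k])   # show the matching image
        ab.set_visible(True)
    fig.canvas.draw_idle()

fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()
```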

Any tips? Thanks!

Pierre

1 Answer


I found a workaround by overlaying my two plots.

In the following section (scatter with hover):

ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")

simply hide the line's markers and add the per-label scatter plots with a legend:

ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker="") # no marker: invisible, only used for hovering
labels = np.unique(y)
for label in labels:
    filtered_by_label = X[y == label]
    ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)
ax.legend()

The hover function still uses the line object to find the point under the cursor (the line is invisible but still reports hits), and the points are colored by label!
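That the invisible line still works for hovering can be checked without a GUI. This sketch uses stand-in data and simulates the mouse with matplotlib.backend_bases.MouseEvent instead of a real cursor. With ls="" the line style is stored as 'None', and in that case Line2D.contains tests the distance to each data point (within its pickradius, 5 points by default) rather than to line segments, whether or not markers are drawn:

```python
from matplotlib import pyplot as plt
from matplotlib.backend_bases import MouseEvent
import numpy as np

# Stand-in coordinates (assumption), spread out in x so points do not overlap
rng = np.random.default_rng(0)
X = np.column_stack([np.linspace(0, 1, 20), rng.random(20)])
xx, yy = X[:, 0], X[:, 1]

fig, ax = plt.subplots()
# Same call as in the answer: no line style and no marker, so nothing is drawn
line, = ax.plot(xx, yy, ls="", marker="")
print(line.get_linestyle())  # 'None'

fig.canvas.draw()  # make sure the data-to-pixel transform uses the final limits
px, py = ax.transData.transform((xx[5], yy[5]))
event = MouseEvent("motion_notify_event", fig.canvas, px, py)
hit, details = line.contains(event)
print(hit, details["ind"])  # True [5]
```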

Pierre