Background
Using a CNN autoencoder, I look at a 2D projection of the latent space of an image dataset. I'd like to hover over the points of the 2D scatter plot and display the corresponding image. I also have the images' true labels and would like to use them to color the scatter points, with a legend.
Setup
My original images are in a 3D array X_plot, my PCA-reduced dataset is in X, and the labels corresponding to the images are in y.
X_plot.shape = (n, 64, 64) # n images of 64x64
X.shape = (n, 2) # list of 2D coordinates for each image
y.shape = (n, ) # n labels
# Example code to reproduce
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np
n = 20
num_classes = 4
X_plot = np.random.rand(n, 64, 64)
X = np.random.rand(n, 2)
y = np.random.randint(num_classes, size=n)
Current code
Scatter with image display on hovering
This is largely inspired by this answer on Stack Overflow.
# Split 2D coordinates into list of xs and ys
xx, yy = zip(*X)
# create figure and plot scatter
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")
# create the annotations box
im = OffsetImage(X_plot[0,:,:], zoom=1, cmap='gray')
xybox=(50., 50.)
ab = AnnotationBbox(im, (0,0), xybox=xybox, xycoords='data',
                    boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)
def hover(event):
    # check whether the mouse is over one of the scatter points
    cont, info = line.contains(event)
    if cont:
        # index of the hovered point in the arrays
        # (take the first one in case several points overlap)
        ind = info["ind"][0]
        # get the figure size in pixels
        w, h = fig.get_size_inches()*fig.dpi
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        # if the event occurs in the top or right half of the figure,
        # flip the annotation box position relative to the mouse
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        # make annotation box visible
        ab.set_visible(True)
        # place it at the position of the hovered scatter point
        ab.xy = (xx[ind], yy[ind])
        # set the image corresponding to that point
        im.set_data(X_plot[ind,:,:])
    else:
        # the mouse is not over a scatter point
        ab.set_visible(False)
    fig.canvas.draw_idle()
# add callback for mouse moves
fig.canvas.mpl_connect('motion_notify_event', hover)
plt.show()
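The two lines computing ws and hs in hover only choose a sign: -1 when the cursor is in the right (or top) half of the figure, +1 otherwise, so the image box is drawn towards the centre and stays inside the window. A minimal sketch of the same rule as a standalone helper, assuming matplotlib's convention of event coordinates in display pixels with the origin at the bottom-left; flip_offset is a hypothetical name, not part of matplotlib:

```python
def flip_offset(event_x, event_y, fig_w, fig_h, xybox=(50.0, 50.0)):
    """Hypothetical helper: offset (in points) for the AnnotationBbox so the
    box points away from the nearest vertical and horizontal figure edge.

    event_x, event_y: cursor position in display pixels (origin bottom-left)
    fig_w, fig_h: figure size in pixels, e.g. fig.get_size_inches() * fig.dpi
    """
    # -1 in the right / top half, +1 otherwise; equivalent to the
    # (a > b)*-1 + (a <= b) expressions used in hover
    ws = -1 if event_x > fig_w / 2 else 1
    hs = -1 if event_y > fig_h / 2 else 1
    return (xybox[0] * ws, xybox[1] * hs)


# Cursor near the bottom-right corner of a 640x480 px figure:
# the box goes up and to the left of the point.
print(flip_offset(600, 40, 640, 480))  # (-50.0, 50.0)
```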

Scatter with legend
If I want to display the 2D scatter with points colored and labeled with y, I use the following code:
fig = plt.figure()
ax = fig.add_subplot(111)
labels = np.unique(y)
for label in labels:
    filtered_by_label = X[y == label]
    ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)
ax.legend()
ax.axis('off')

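One way to get a per-label legend from a single artist (the shape the hover callback expects) is to draw all points with one ax.scatter call, colour them by y through the c argument, and let PathCollection.legend_elements() (matplotlib 3.1 and later) build one legend handle per unique label. A minimal sketch, assuming y holds integer class labels; labelled_scatter is a hypothetical helper name:

```python
import matplotlib.pyplot as plt
import numpy as np


def labelled_scatter(X, y, cmap="tab10"):
    """Hypothetical helper: one scatter artist for all points, coloured by y,
    with a legend entry per unique label."""
    fig, ax = plt.subplots()
    # A single PathCollection; each point's colour comes from its label
    sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, s=12, marker=".", alpha=0.9)
    # legend_elements() returns proxy handles and labels for the unique values of c
    handles, labels = sc.legend_elements()
    ax.legend(handles, labels, title="label")
    ax.axis("off")
    return fig, ax, sc
```

Because sc is a single artist drawn from X in its original order, the indices in sc.contains(event)[1]["ind"] are rows of X and X_plot, so sc could take the place of line in the hover function.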
Challenge
I can't merge the two pieces of code above. ax.plot doesn't seem to accept a list of labels for the legend. If I use the label loop from the second snippet instead, I get one artist per label, whereas the hover function relies on a single line object. I tried merging several of these artists into one, without success.
Any tips? Thanks!
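A minimal sketch of one way to combine the two snippets while keeping the per-label loop: store each scatter artist together with the row indices of X it was drawn from, and in the callback ask every artist whether it contains the cursor. Collection.contains returns indices relative to that artist's own points, so they are mapped back to rows of X and X_plot through the stored indices. scatter_with_hover is a hypothetical helper name; it returns its artists and callback only so they can be inspected or tested.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import AnnotationBbox, OffsetImage


def scatter_with_hover(X, X_plot, y, xybox=(50.0, 50.0)):
    """Hypothetical helper: per-label coloured scatter with a legend, showing
    the matching image from X_plot when the mouse hovers over a point."""
    fig, ax = plt.subplots()
    # One scatter per label, remembering which rows of X each one holds
    artists = []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        sc = ax.scatter(X[idx, 0], X[idx, 1], s=12, marker=".",
                        alpha=0.9, label=label)
        artists.append((sc, idx))
    ax.legend()

    # A single hidden annotation box, reused for every hovered point
    im = OffsetImage(X_plot[0], zoom=1, cmap="gray")
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords="data",
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        for sc, idx in artists:
            hit, info = sc.contains(event)
            if hit:
                # info["ind"] indexes this artist's points; map back to X
                i = idx[info["ind"][0]]
                w, h = fig.get_size_inches() * fig.dpi
                # Flip the box towards the figure centre so it stays visible
                ws = -1 if event.x > w / 2 else 1
                hs = -1 if event.y > h / 2 else 1
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                ab.xy = (X[i, 0], X[i, 1])
                im.set_data(X_plot[i])
                ab.set_visible(True)
                break
        else:
            # No artist under the cursor: hide the image
            ab.set_visible(False)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
    return fig, ax, ab, im, hover
```

With the data from the Setup section this would be called as fig, *_ = scatter_with_hover(X, X_plot, y) followed by plt.show(). The legend and the hover both use the same per-label artists, so nothing has to be merged into a single Line2D.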