SOLUTION
What works for me is to normalize my matrices to the range 0 to 1 rather than 0 to 255. I have no idea why this works, but if someone has an explanation, I'll accept the answer.
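A likely explanation, offered as an assumption rather than a confirmed diagnosis: in the script below, `images = np.zeros((N_SAMPLES, 3, 32, 32))` uses NumPy's default dtype, float64. The uint8 pixel values from the canvas are therefore stored as floats between 0.0 and 255.0. Matplotlib expects float RGB data in [0, 1] and integer RGB data in [0, 255], so floats above 1 are out of range. Recent versions clip such values and emit a warning; older versions may have handled them differently. Dividing by 255 puts the floats back into the expected range, and storing the pixels as uint8 avoids the problem entirely. The NumPy-only sketch below shows the dtype effect on a hypothetical stand-in frame. For simplicity it uses an (N, 32, 32, 3) layout instead of the transposed layout in the question.

```python
import numpy as np

# Hypothetical stand-in for one rendered 32x32 RGB frame: white (255)
# background with a black (0) rectangle, stored as uint8 like the canvas output.
frame = np.full((32, 32, 3), 255, dtype=np.uint8)
frame[8:20, 4:12] = 0

# np.zeros defaults to float64, so the uint8 values are silently upcast
# to floats in the range 0.0-255.0 when stored in this array.
images_float = np.zeros((1, 32, 32, 3))
images_float[0] = frame

# Option 1: keep float storage but rescale to [0, 1], the range
# matplotlib expects for float RGB data.
images_scaled = images_float / 255.0

# Option 2: store the pixels as uint8, the range matplotlib expects
# for integer RGB data, so no rescaling is needed.
images_uint8 = np.zeros((1, 32, 32, 3), dtype=np.uint8)
images_uint8[0] = frame

print(images_float.dtype, images_float.max())    # float64 255.0
print(images_scaled.dtype, images_scaled.max())  # float64 1.0
print(images_uint8.dtype, images_uint8.max())    # uint8 255
```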
ORIGINAL QUESTION
I have a dataset of images that are black rectangles on white fields. For example, here are three such images (sorry, no borders):
I want to build a scatter plot in which each data point is represented as an image. I used the code from this SO answer, but got a plot that looked like this (plotted with borders for debugging):
As you can see, the images all look like squares. In fact, they should be various sized rectangles with different positions inside the frame. Is there a way to fix this?
Here is some code that generates some images and then plots them as scatter plot points. You can see that the images it saves (you need to create an images directory in the same directory from which you run the script) are different from the images that are plotted by matplotlib:
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.patches as patches
import random
import numpy as np

# ------------------------------------------------------------------------------
def create_image(i):
    # Random position and size of the black rectangle, in axes coordinates.
    rx = random.random()
    ry = random.random()
    rw = random.random() / 2
    rh = random.random() / 2
    DPI = 72
    fig = plt.figure(frameon=False,
                     figsize=(32/DPI, 32/DPI),
                     dpi=DPI)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    rect = patches.Rectangle((rx, ry),
                             rw, rh,
                             linewidth=1,
                             edgecolor='none',
                             facecolor=(0, 0, 0))
    ax.add_patch(rect)
    fig.canvas.draw()
    # Grab the rendered RGB pixels as a flat uint8 array.
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    width, height = fig.get_size_inches() * fig.get_dpi()
    img = img.reshape((32, 32, 3))
    # (32, 32, 3) -> (3, 32, 32) to fit the images array; imscatter transposes back.
    img = img.T
    fig.savefig('images/%s.png' % i)
    plt.clf()
    plt.close()
    return img

# ------------------------------------------------------------------------------
def imscatter(points, images, ax):
    # Place each image at its (x, y) point as an annotation box.
    for (x, y), image in zip(points, images):
        im = OffsetImage(image.T, zoom=1)
        ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.2)
        ax.add_artist(ab)

# ------------------------------------------------------------------------------
fig, ax = plt.subplots()
# Create ten random points.
N_SAMPLES = 10
points = np.random.random((N_SAMPLES, 2)) * 100
images = np.zeros((N_SAMPLES, 3, 32, 32))
for i in range(N_SAMPLES):
    images[i] = create_image(i)
Xp = points[:, 0]
Yp = points[:, 1]
ax.set_xlim([Xp.min().astype(int), Xp.max().astype(int)])
ax.set_ylim([Yp.min().astype(int), Yp.max().astype(int)])
imscatter(points, images, ax)
plt.show()
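For comparison, here is a minimal sketch of the same scatter plot with the normalization from the SOLUTION section applied. It is not the asker's code: the helper `make_rect_image` is hypothetical and draws the rectangles directly into uint8 NumPy arrays instead of rendering a matplotlib figure, which sidesteps any canvas-size or DPI effects. Rescaling to [0, 1] inside `imscatter` mirrors the fix above. Passing the uint8 arrays unchanged should also work, since integer RGB data is read as [0, 255]. The Agg backend is assumed so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

SIZE = 32        # image side length in pixels, as in the question
N_SAMPLES = 10

rng = np.random.default_rng(0)

def make_rect_image(rng, size=SIZE):
    """Hypothetical helper: (size, size, 3) uint8 image, white field, one black rectangle."""
    img = np.full((size, size, 3), 255, dtype=np.uint8)
    x0, y0 = rng.integers(0, size // 2, size=2)   # top-left corner
    w, h = rng.integers(1, size // 2, size=2)     # width and height, at least 1 px
    img[y0:y0 + h, x0:x0 + w] = 0
    return img

def imscatter(points, images, ax, zoom=1):
    """Place each image at its (x, y) point as an AnnotationBbox."""
    for (x, y), image in zip(points, images):
        # Rescale to [0, 1]: matplotlib reads float RGB data in that range.
        im = OffsetImage(image / 255.0, zoom=zoom)
        ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.2)
        ax.add_artist(ab)

points = rng.random((N_SAMPLES, 2)) * 100
images = np.stack([make_rect_image(rng) for _ in range(N_SAMPLES)])

fig, ax = plt.subplots()
ax.set_xlim(points[:, 0].min(), points[:, 0].max())
ax.set_ylim(points[:, 1].min(), points[:, 1].max())
imscatter(points, images, ax)
fig.canvas.draw()  # render; use fig.savefig(...) or plt.show() to view it
```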
From the comments:
@ImportanceOfBeingErnest gets this image when running my script locally, while I get this image. My guess is that this is related to DPI or resolution somehow.
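One way to test the DPI guess is to check how large the rendered canvas actually is, instead of hard-coding `reshape((32, 32, 3))`. The following is a minimal diagnostic sketch that assumes the Agg backend. It uses matplotlib's `get_width_height()` and `buffer_rgba()` on the canvas. The latter returns the pixels as a (height, width, 4) buffer, so its shape reflects the real render size.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

DPI = 72
fig = plt.figure(frameon=False, figsize=(32 / DPI, 32 / DPI), dpi=DPI)
ax = fig.add_axes([0, 0, 1, 1])
ax.axis("off")
fig.canvas.draw()

# Size of the rendered canvas in pixels, as matplotlib reports it.
width, height = fig.canvas.get_width_height()

# RGBA pixels as a (height, width, 4) uint8 array; drop alpha to get RGB.
rgb = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]

print("canvas:", width, "x", height, "buffer shape:", rgb.shape)
plt.close(fig)
```

If the printed size is not 32 x 32 on a given machine, a hard-coded reshape would misread the pixels.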