Context and examples of symptoms
I am using a neural network to do super-resolution (increasing the resolution of images). However, since an image can be large, I need to split it into multiple smaller images, make predictions on each of them separately, and then merge the results back together.
Here are examples of what this gives me:
Example 1: you can see a subtle vertical line passing through the shoulder of the skier in the output picture.
Example 2: once you start seeing them, you'll notice that the subtle lines are forming squares throughout the whole image (remnants of the way I segmented the image for individual predictions).
Example 3: you can clearly see the vertical line crossing the lake.
Source of the problem
Basically, my network makes poor predictions along the edges, which I believe is normal since there is less "surrounding" information.
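As a toy illustration of this edge effect (not the actual network), the sketch below runs a zero-padded averaging filter over a perfectly flat signal. It assumes the model uses zero-padded ("same") convolutions, which is an assumption and has not been verified here; `conv_same_zero_pad` is a hypothetical stand-in for a single convolution layer.

```python
import numpy as np


def conv_same_zero_pad(signal, kernel):
    """1-D 'same' convolution; values outside the signal are treated as zeros."""
    return np.convolve(signal, kernel, mode="same")


signal = np.ones(10)        # perfectly flat input
kernel = np.ones(5) / 5.0   # 5-tap averaging filter
border = len(kernel) // 2   # number of samples affected on each side

out = conv_same_zero_pad(signal, kernel)

print("interior:", out[border:-border])  # unaffected by the padding
print("left edge:", out[:border])        # pulled towards zero by the padding
print("right edge:", out[-border:])
```

The interior samples keep the correct value, while the samples within half a kernel width of each edge are distorted. With several stacked layers the affected band gets wider, which would be consistent with seams appearing where segments meet.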
Source code
import numpy as np
import matplotlib.pyplot as plt
import skimage.io
from keras.models import load_model
from constants import verbosity, save_dir, overlap, \
    model_name, tests_path, input_width, input_height
from utils import float_im
def predict(args):
    model = load_model(save_dir + '/' + args.model)

    image = skimage.io.imread(tests_path + args.image)[:, :, :3]  # removing possible extra channels (Alpha)
    print("Image shape:", image.shape)

    predictions = []
    images = []

    crops = seq_crop(image)  # crops the image into multiple sub-parts based on the 'input_' constants

    for i in range(len(crops)):  # number of vertical crops
        for j in range(len(crops[0])):  # number of horizontal crops
            current_image = crops[i][j]
            images.append(current_image)

    print("Moving on to predictions. Amount:", len(images))

    for p in range(len(images)):
        if p % 3 == 0 and verbosity == 2:
            print("--prediction #", p)

        # Hack because GPU can only handle one image at a time
        input_img = (np.expand_dims(images[p], 0))  # Add the image to a batch where it's the only member
        predictions.append(model.predict(input_img)[0])  # returns a list of lists, one for each image in the batch

    return predictions, image, crops
def show_pred_output(input, pred):
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")

    plt.subplot(1, 2, 1)
    plt.title("Input : " + str(input.shape[1]) + "x" + str(input.shape[0]))
    plt.imshow(input, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.subplot(1, 2, 2)
    plt.title("Output : " + str(pred.shape[1]) + "x" + str(pred.shape[0]))
    plt.imshow(pred, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.show()
# adapted from https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
    """
    To crop the whole image in a list of sub-images of the same size.
    Size comes from "input_" variables in the 'constants' (Evaluation).
    Padding with 0 the Bottom and Right image.
    :param img: input image
    :return: list of sub-images with defined size
    """
    width_shape = ceildiv(img.shape[1], input_width)
    height_shape = ceildiv(img.shape[0], input_height)
    sub_images = []  # will contain all the cropped sub-parts of the image

    for j in range(height_shape):
        horizontal = []
        for i in range(width_shape):
            horizontal.append(crop_precise(img, i*input_width, j*input_height, input_width, input_height))
        sub_images.append(horizontal)

    return sub_images
def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    To crop a precise portion of an image.
    When trying to crop outside of the boundaries, the input is padded with zeros.
    :param img: image to crop
    :param coord_x: width coordinate (top left point)
    :param coord_y: height coordinate (top left point)
    :param width_length: width of the cropped portion starting from coord_x
    :param height_length: height of the cropped portion starting from coord_y
    :return: the cropped part of the image
    """
    tmp_img = img[coord_y:coord_y + height_length, coord_x:coord_x + width_length]
    return float_im(tmp_img)  # From [0,255] to [0.,1.]
# from https://stackoverflow.com/a/17511341/9768291
def ceildiv(a, b):
    return -(-a // b)
# adapted from https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops):
    # unflatten predictions
    def nest(data, template):
        data = iter(data)
        return [[next(data) for _ in row] for row in template]

    if len(crops) != 0:
        predictions = nest(predictions, crops)

    H = np.cumsum([x[0].shape[0] for x in predictions])
    W = np.cumsum([x.shape[1] for x in predictions[0]])
    D = predictions[0][0]
    recon = np.empty((H[-1], W[-1], D.shape[2]), D.dtype)
    for rd, rs in zip(np.split(recon, H[:-1], 0), predictions):
        for d, s in zip(np.split(rd, W[:-1], 1), rs):
            d[...] = s
    return recon
if __name__ == '__main__':
    print(" - ", args)

    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)  # reconstructs the enhanced image from predictions

    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)

    show_pred_output(original, enhanced)
The question (what I want)
There are many obvious naive approaches to solving this problem, but I'm convinced there must be a very concise way of doing it: how do I add an overlap_amount variable that lets me make overlapping predictions, discarding the "edge parts" of each sub-image ("segment") and replacing them with the predictions from the surrounding segments (where the same pixels would not be "edge predictions")?
I, of course, want to minimize the number of "useless" predictions (pixels to be discarded). It is also worth noting that each input segment produces an output segment that is 4 times larger in each dimension (e.g. a 20x20 pixel input yields an 80x80 pixel output).
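To make the request concrete, the overlap idea could look roughly like the sketch below. It is only an illustration under several assumptions, not a definitive implementation: tiles are square (`tile` stands in for `input_width` and `input_height`), the image border is filled by reflection rather than zeros, and `predict_fn` is a hypothetical callable that maps one `(tile, tile, C)` patch to a `(tile*scale, tile*scale, C)` prediction. The names `predict_with_overlap` and `fake_model` are made up for this example.

```python
import numpy as np


def predict_with_overlap(image, predict_fn, tile=20, overlap_amount=4, scale=4):
    """Predict on overlapping tiles and keep only the centre of each prediction."""
    h, w, c = image.shape
    core = tile - 2 * overlap_amount  # input pixels kept from each tile
    assert core > 0, "overlap_amount must be smaller than tile / 2"

    # Extra padding so that the grid of cores covers the whole image.
    pad_h = (-h) % core
    pad_w = (-w) % core
    padded = np.pad(
        image,
        ((overlap_amount, overlap_amount + pad_h),
         (overlap_amount, overlap_amount + pad_w),
         (0, 0)),
        mode="reflect",
    )

    out = np.empty(((h + pad_h) * scale, (w + pad_w) * scale, c), dtype=np.float32)
    margin = overlap_amount * scale  # output pixels discarded on each side

    for y in range(0, h + pad_h, core):
        for x in range(0, w + pad_w, core):
            patch = padded[y:y + tile, x:x + tile]
            pred = predict_fn(patch)
            # Drop the "edge predictions" and keep the central core only.
            centre = pred[margin:margin + core * scale, margin:margin + core * scale]
            out[y * scale:(y + core) * scale, x * scale:(x + core) * scale] = centre

    return out[:h * scale, :w * scale]  # remove the extra padding again


def fake_model(patch, scale=4):
    """Stand-in for the network: nearest-neighbour upscaling."""
    return np.repeat(np.repeat(patch, scale, axis=0), scale, axis=1)


rng = np.random.default_rng(0)
image = rng.random((50, 70, 3)).astype(np.float32)
result = predict_with_overlap(image, fake_model, tile=20, overlap_amount=4, scale=4)
print(result.shape)
```

With the real network, `predict_fn` would presumably be something like `lambda p: model.predict(np.expand_dims(p, 0))[0]` applied to patches already converted with `float_im`. The cost of the overlap is easy to quantify: each tile keeps only `(core / tile) ** 2` of its predicted pixels, so with `tile=20` and `overlap_amount=4` about 64% of the predicted pixels are discarded.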