
Context and examples of symptoms

I am using a neural network to do super-resolution (increase the resolution of images). However, since an image can be large, I need to split it into multiple smaller images, make predictions on each of them separately, and then merge the results back together.

Here are examples of what this gives me:

example 1 example 2 example 3

Example 1: you can see a subtle vertical line passing through the shoulder of the skier in the output picture.

Example 2: once you start seeing them, you'll notice that the subtle lines are forming squares throughout the whole image (remnants of the way I segmented the image for individual predictions).

Example 3: you can clearly see the vertical line crossing the lake.


Source of the problem

Basically, my network makes poor predictions along the edges, which I believe is normal since there is less "surrounding" information.
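To illustrate why a tile's border can differ from what the network would produce on the full image, here is a minimal sketch that uses a zero-padded 3x3 mean filter as a stand-in for a convolutional layer. The filter, the `mean_filter_3x3` name, and the 8x8 random image are assumptions made for illustration only; they are not the network from this post.

```python
import numpy as np

def mean_filter_3x3(img):
    """3x3 mean filter with zero padding (hypothetical stand-in for one conv layer)."""
    padded = np.pad(img, 1, mode="constant")
    out = np.zeros_like(img, dtype=float)
    h, w = img.shape
    for dy in range(3):
        for dx in range(3):
            out += padded[dy:dy + h, dx:dx + w]
    return out / 9.0

rng = np.random.default_rng(0)
image = rng.random((8, 8))

full = mean_filter_3x3(image)         # the filter sees the whole image
tile = mean_filter_3x3(image[:, :4])  # the filter sees only the left half

# Columns away from the cut agree with the full-image result...
interior_matches = np.allclose(tile[:, :3], full[:, :3])
# ...but the column at the cut differs, because zeros replaced the real neighbours.
seam_differs = not np.allclose(tile[:, 3], full[:, 3])
print(interior_matches, seam_differs)
```

With a deeper network the affected band is wider than one pixel, which is why an overlap of several pixels is usually needed.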


Source code

import numpy as np
import matplotlib.pyplot as plt
import skimage.io

from keras.models import load_model

from constants import verbosity, save_dir, overlap, \
    model_name, tests_path, input_width, input_height
from utils import float_im

def predict(args):
    model = load_model(save_dir + '/' + args.model)

    image = skimage.io.imread(tests_path + args.image)[:, :, :3]  # removing possible extra channels (Alpha)
    print("Image shape:", image.shape)

    predictions = []
    images = []

    crops = seq_crop(image)  # crops into multiple sub-parts the image based on 'input_' constants

    for i in range(len(crops)):  # amount of vertical crops
        for j in range(len(crops[0])):  # amount of horizontal crops
            current_image = crops[i][j]
            images.append(current_image)

    print("Moving on to predictions. Amount:", len(images))

    for p in range(len(images)):
        if p%3 == 0 and verbosity == 2:
            print("--prediction #", p)
        # Hack because GPU can only handle one image at a time
        input_img = (np.expand_dims(images[p], 0))       # Add the image to a batch where it's the only member
        predictions.append(model.predict(input_img)[0])  # returns a list of lists, one for each image in the batch

    return predictions, image, crops


def show_pred_output(input, pred):
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")

    plt.subplot(1, 2, 1)
    plt.title("Input : " + str(input.shape[1]) + "x" + str(input.shape[0]))
    plt.imshow(input, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.subplot(1, 2, 2)
    plt.title("Output : " + str(pred.shape[1]) + "x" + str(pred.shape[0]))
    plt.imshow(pred, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.show()


# adapted from  https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
    """
    To crop the whole image in a list of sub-images of the same size.
    Size comes from "input_" variables in the 'constants' (Evaluation).
    Padding with 0 the Bottom and Right image.
    :param img: input image
    :return: list of sub-images with defined size
    """
    width_shape = ceildiv(img.shape[1], input_width)
    height_shape = ceildiv(img.shape[0], input_height)
    sub_images = []  # will contain all the cropped sub-parts of the image

    for j in range(height_shape):
        horizontal = []
        for i in range(width_shape):
            horizontal.append(crop_precise(img, i*input_width, j*input_height, input_width, input_height))
        sub_images.append(horizontal)

    return sub_images


def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    To crop a precise portion of an image.
    When trying to crop outside of the boundaries, the input is padded with zeros.
    :param img: image to crop
    :param coord_x: width coordinate (top left point)
    :param coord_y: height coordinate (top left point)
    :param width_length: width of the cropped portion starting from coord_x
    :param height_length: height of the cropped portion starting from coord_y
    :return: the cropped part of the image
    """

    tmp_img = img[coord_y:coord_y + height_length, coord_x:coord_x + width_length]

    return float_im(tmp_img)  # From [0,255] to [0.,1.]


# from  https://stackoverflow.com/a/17511341/9768291
def ceildiv(a, b):
    return -(-a // b)


# adapted from  https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops):

    # unflatten predictions
    def nest(data, template):
        data = iter(data)
        return [[next(data) for _ in row] for row in template]

    if len(crops) != 0:
        predictions = nest(predictions, crops)

    H = np.cumsum([x[0].shape[0] for x in predictions])
    W = np.cumsum([x.shape[1] for x in predictions[0]])
    D = predictions[0][0]
    recon = np.empty((H[-1], W[-1], D.shape[2]), D.dtype)
    for rd, rs in zip(np.split(recon, H[:-1], 0), predictions):
        for d, s in zip(np.split(rd, W[:-1], 1), rs):
            d[...] = s
    return recon


if __name__ == '__main__':
    print("   -  ", args)

    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)    # reconstructs the enhanced image from predictions

    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)

    show_pred_output(original, enhanced)

The question (what I want)

There are many obvious naive approaches to solving this problem, but I'm convinced there must be a very concise way of doing it: how do I add an overlap_amount variable that would let me make overlapping predictions, so that I can discard the "edge parts" of each sub-image ("segment") and replace them with the predictions made on the neighbouring segments (where those pixels are not on an edge)?

I, of course, want to minimize the amount of "useless" predictions (pixels to be discarded). It might also be worth noting that each input segment produces an output segment that is 4 times larger in each dimension (e.g. a 20x20-pixel input yields an 80x80-pixel output).
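To get a feel for how many predictions are "wasted" for a given overlap, the arithmetic can be sketched as follows. This assumes each tile keeps only its central `tile - overlap` pixels per axis and that the image is padded so whole tiles fit; the function names and the 480x480 example are hypothetical. The upscaling factor does not change the fraction, since it scales kept and discarded pixels alike.

```python
import math

def tiles_along_axis(length, tile, overlap):
    """Number of tiles needed along one axis when tiles advance by (tile - overlap)."""
    stride = tile - overlap
    return math.ceil(length / stride)

def discarded_fraction(height, width, tile, overlap):
    """Fraction of predicted pixels that are thrown away (edges plus padding)."""
    n_tiles = (tiles_along_axis(height, tile, overlap)
               * tiles_along_axis(width, tile, overlap))
    predicted = n_tiles * tile * tile
    return 1.0 - (height * width) / predicted

n = tiles_along_axis(480, 64, 16)             # 480 / 48 = 10 tiles per axis
waste = discarded_fraction(480, 480, 64, 16)  # 1 - 230400 / 409600 = 0.4375
print(n, waste)
```

A smaller overlap lowers the waste but leaves less margin for the edge errors, so the value has to be tuned against the visible seams.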

payne
  • Why splitting the *image* to separate parts? So each part could be processed on another thread/process? Maybe the workload should be split on the network part. – Eran W May 09 '19 at 11:54
  • @EranW trying to pass a whole image through a neural network to get a prediction computed on my computer's GPU ends up giving me an `OOM` (Out Of Memory) error, that is why I need to split the image in separate parts and use the CPU to merge them all back together properly. – payne May 09 '19 at 12:41
  • I would start with the overlap approach (both in rows and cols) and try to find a value as small as possible to reduce the extra inferences. You still need to figure out how to mix the overlapped predictions (mean or max operators for example) – m33n May 10 '19 at 13:53
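The mean operator mentioned in the last comment could be sketched roughly like this: sum every tile into an accumulator, count how many tiles cover each pixel, and divide. The `blend_mean` helper, its arguments, and the toy tiles are assumptions for illustration, not code from this thread.

```python
import numpy as np

def blend_mean(tiles, positions, out_shape):
    """Average overlapping tiles.

    tiles: list of (h, w, c) arrays; positions: (y, x) top-left corner of
    each tile in the output; out_shape: (H, W, c) of the blended result.
    """
    acc = np.zeros(out_shape, dtype=float)
    count = np.zeros(out_shape[:2], dtype=float)
    for tile, (y, x) in zip(tiles, positions):
        h, w = tile.shape[:2]
        acc[y:y + h, x:x + w] += tile
        count[y:y + h, x:x + w] += 1
    count = np.maximum(count, 1)    # pixels covered by no tile stay at 0
    return acc / count[..., None]   # broadcast the count over the channels

# Two 4x6 tiles that overlap on columns 4 and 5 of a 4x10 output.
dark = np.zeros((4, 6, 3))
bright = np.ones((4, 6, 3))
blended = blend_mean([dark, bright], [(0, 0), (0, 4)], (4, 10, 3))
print(blended[0, :, 0])
```

In the overlapping columns the result is the average of the two tiles; elsewhere each tile's own value is kept.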

2 Answers


I solved a similar problem by moving inference to the CPU. It was much, much slower, but at least in my case it solved the patch-border problem better than the overlapping-ROI approaches (voting-based or discarding-based) that I also tested.

Assuming you are using the Tensorflow backend:

import tensorflow as tf

with tf.device('/cpu:0'):
    prediction = model.predict(...)

Of course assuming that you have enough RAM to fit your model. Comment below if that is not the case and I'll check out if there's something in my code that could be used here.

Tapio
  • Interestingly, I never even thought about this as a solution, but it is indeed legit. However, I would still prefer getting a GPU-oriented solution. – payne May 11 '19 at 07:29
  • Do you happen to have the code you had used for your different tests which led you to opt for this solution? – payne May 11 '19 at 11:58
  • Hitting you up with some news: trying out the CPU thing basically froze my computer and I had to restart it. – payne Aug 06 '19 at 12:16
  • Eh, sorry, I had completely forgot about this. Did you check whether it filled your whole RAM before crashing? – Tapio Aug 06 '19 at 14:07
  • I was in the process of opening the Task Manager when everything froze. I think it fairly safe to assume that is what happened, and I don't specifically feel like forcing my computer into another situation where I'll have to reboot it manually. Anyhow, I just got back on this project, and thought I'd try this lazy solution, but in reality I really want a segmented solution (and I just got onto working on the naive implementation). – payne Aug 06 '19 at 14:28

Solved it through a naive approach. It could be much better, but at least this works.

The process

Basically, it takes the initial image, adds padding around it, then crops it into multiple sub-images which are all lined up into an array. The crops are done so that each sub-image overlaps its surrounding neighbours.

Then, each sub-image is fed into the network and the predictions are collected (4x the resolution of the input, in this case). When reconstructing the image, each prediction is taken individually and its edge is cropped out (since it contains errors). The cropping is done so that the gluing of all the predictions ends up with no overlap, and only the middle parts of the predictions coming from the neural network are stuck together.

Finally, the surrounding padding is removed.
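As a compact illustration of the pad / crop / predict / trim / glue steps, here is a minimal sketch that swaps the real network for a nearest-neighbour upscaler, so the tiled result can be checked against upscaling the whole image at once. It is not the code of this answer: `fake_model`, `upscale_tiled`, and the upper-case constants are hypothetical names, and this version pads by half the overlap on the top and left, which is an assumption of the sketch.

```python
import numpy as np

SCALE = 4     # upscaling factor of the stand-in model
TILE = 64     # input tile size (square)
OVERLAP = 16  # overlap between neighbouring tiles (assumed even)

def fake_model(tile):
    """Stand-in for model.predict: nearest-neighbour upscaling by SCALE."""
    return tile.repeat(SCALE, axis=0).repeat(SCALE, axis=1)

def upscale_tiled(image):
    half = OVERLAP // 2
    stride = TILE - OVERLAP
    h, w = image.shape[:2]
    n_rows = -(-h // stride)  # ceiling division
    n_cols = -(-w // stride)
    # Pad so that every tile is full-size and every original pixel falls
    # in the central (kept) region of exactly one tile.
    pad_h = n_rows * stride + OVERLAP - h - half
    pad_w = n_cols * stride + OVERLAP - w - half
    padded = np.pad(image, ((half, pad_h), (half, pad_w), (0, 0)), mode="constant")

    step = stride * SCALE  # size of the kept centre of each prediction
    trim = half * SCALE    # border discarded on each side of a prediction
    out = np.empty((n_rows * step, n_cols * step, image.shape[2]), dtype=float)
    for r in range(n_rows):
        for c in range(n_cols):
            tile = padded[r * stride:r * stride + TILE,
                          c * stride:c * stride + TILE]
            pred = fake_model(tile)
            out[r * step:(r + 1) * step, c * step:(c + 1) * step] = \
                pred[trim:trim + step, trim:trim + step]
    # Remove what came from the bottom/right padding.
    return out[:h * SCALE, :w * SCALE]

rng = np.random.default_rng(0)
image = rng.random((100, 130, 3))
tiled = upscale_tiled(image)
whole = fake_model(image)
print(tiled.shape, np.array_equal(tiled, whole))
```

Because the stand-in model is purely local, the tiled and whole-image results match exactly here; with a real network they would only match where the trimmed border is wide enough to absorb the edge errors.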

Result

No more lines! :D

Proper prediction

Code

import numpy as np
import matplotlib.pyplot as plt
import skimage.io

from keras.models import load_model

from constants import verbosity, save_dir, overlap, \
    model_name, tests_path, input_width, input_height, scale_fact
from utils import float_im


def predict(args):
    """
    Super-resolution on the input image using the model.

    :param args:
    :return:
        'predictions' contains an array of every single cropped sub-image once enhanced (the outputs of the model).
        'image' is the original image, untouched.
        'crops' is the array of every single cropped sub-image that will be used as input to the model.
    """
    model = load_model(save_dir + '/' + args.model)

    image = skimage.io.imread(tests_path + args.image)[:, :, :3]  # removing possible extra channels (Alpha)
    print("Image shape:", image.shape)

    predictions = []
    images = []

    # Padding and cropping the image
    overlap_pad = (overlap, overlap)  # padding tuple
    pad_width = (overlap_pad, overlap_pad, (0, 0))  # assumes color channel as last
    padded_image = np.pad(image, pad_width, 'constant')  # padding the border
    crops = seq_crop(padded_image)  # crops into multiple sub-parts the image based on 'input_' constants

    # Arranging the divided image into a single-dimension array of sub-images
    for i in range(len(crops)):         # amount of vertical crops
        for j in range(len(crops[0])):  # amount of horizontal crops
            current_image = crops[i][j]
            images.append(current_image)

    print("Moving on to predictions. Amount:", len(images))
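    # Half of the overlap, expressed in output pixels (equals overlap * 2 when scale_fact is 4)
    upscaled_overlap = overlap * scale_fact // 2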
    for p in range(len(images)):
        if p % 3 == 0 and verbosity == 2:
            print("--prediction #", p)

        # Hack due to some GPUs that can only handle one image at a time
        input_img = (np.expand_dims(images[p], 0))  # Add the image to a batch where it's the only member
        pred = model.predict(input_img)[0]          # returns a list of lists, one for each image in the batch

        # Cropping the useless parts of the overlapped predictions (to prevent the repeated erroneous edge-prediction)
        pred = pred[upscaled_overlap:pred.shape[0]-upscaled_overlap, upscaled_overlap:pred.shape[1]-upscaled_overlap]

        predictions.append(pred)
    return predictions, image, crops


def show_pred_output(input, pred):
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")

    plt.subplot(1, 2, 1)
    plt.title("Input : " + str(input.shape[1]) + "x" + str(input.shape[0]))
    plt.imshow(input, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.subplot(1, 2, 2)
    plt.title("Output : " + str(pred.shape[1]) + "x" + str(pred.shape[0]))
    plt.imshow(pred, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.show()


# adapted from  https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
    """
    To crop the whole image in a list of sub-images of the same size.
    Size comes from "input_" variables in the 'constants' (Evaluation).
    Padding with 0 the Bottom and Right image.

    :param img: input image
    :return: list of sub-images with defined size (as per 'constants')
    """
    sub_images = []  # will contain all the cropped sub-parts of the image
    j, shifted_height = 0, 0
    while shifted_height < (img.shape[0] - input_height):
        horizontal = []
        shifted_height = j * (input_height - overlap)
        i, shifted_width = 0, 0
        while shifted_width < (img.shape[1] - input_width):
            shifted_width = i * (input_width - overlap)
            horizontal.append(crop_precise(img,
                                           shifted_width,
                                           shifted_height,
                                           input_width,
                                           input_height))
            i += 1
        sub_images.append(horizontal)
        j += 1

    return sub_images


def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    To crop a precise portion of an image.
    When trying to crop outside of the boundaries, the input is padded with zeros.

    :param img: image to crop
    :param coord_x: width coordinate (top left point)
    :param coord_y: height coordinate (top left point)
    :param width_length: width of the cropped portion starting from coord_x (toward right)
    :param height_length: height of the cropped portion starting from coord_y (toward bottom)
    :return: the cropped part of the image
    """
    tmp_img = img[coord_y:coord_y + height_length, coord_x:coord_x + width_length]
    return float_im(tmp_img)  # From [0,255] to [0.,1.]


# adapted from  https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops):
    """
    Used to reconstruct a whole image from an array of mini-predictions.
    The image had to be split in sub-images because the GPU's memory
    couldn't handle the prediction on a whole image.

    :param predictions: an array of upsampled images, from left to right, top to bottom.
    :param crops: 2D array of the cropped images
    :return: the reconstructed image as a whole
    """

    # unflatten predictions
    def nest(data, template):
        data = iter(data)
        return [[next(data) for _ in row] for row in template]

    if len(crops) != 0:
        predictions = nest(predictions, crops)

    # At this point "predictions" is a 2D grid (list of rows) of the individual outputs
    H = np.cumsum([x[0].shape[0] for x in predictions])
    W = np.cumsum([x.shape[1] for x in predictions[0]])
    D = predictions[0][0]
    recon = np.empty((H[-1], W[-1], D.shape[2]), D.dtype)
    for rd, rs in zip(np.split(recon, H[:-1], 0), predictions):
        for d, s in zip(np.split(rd, W[:-1], 1), rs):
            d[...] = s

    # Removing the pad from the reconstruction
    tmp_overlap = overlap * (scale_fact - 1)  # using "-2" leaves the outer edge-prediction error
    return recon[tmp_overlap:recon.shape[0]-tmp_overlap, tmp_overlap:recon.shape[1]-tmp_overlap]


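if __name__ == '__main__':
    # NOTE: the post does not show how `args` is built. This parser is an assumed
    # reconstruction based on the attributes used above (model, image, save).
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('model')  # model file name inside save_dir
    parser.add_argument('image')  # input image name inside tests_path
    parser.add_argument('save')   # output file name inside 'output/'
    args = parser.parse_args()
    print("   -  ", args)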

    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)    # reconstructs the enhanced image from predictions

    # Save and display the result
    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)
    show_pred_output(original, enhanced)

Constants and extra bits

verbosity = 2

input_width = 64

input_height = 64

overlap = 16

scale_fact = 4

def float_im(img):
    return np.divide(img, 255.)

Alternative

A possibly better alternative which you might want to consider if you run into the same kind of problem as me; it's the same basic idea, but more polished and perfected.

payne