
I followed these instructions and wrote the following code to create a Dataset for images (the COCO2014 training set):

from pathlib import Path
import tensorflow as tf


def image_dataset(filepath, image_size, batch_size, norm=True):
    def preprocess_image(image):
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        if norm:
            image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(path):
        image = tf.read_file(path)
        return preprocess_image(image)

    all_image_paths = [str(f) for f in Path(filepath).glob('*')]
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.shuffle(buffer_size = len(all_image_paths))
    ds = ds.repeat()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

ds = image_dataset(train2014_dir, (256, 256), 4, False)
image = ds.make_one_shot_iterator().get_next('images')
# image is then fed to the network

This code always runs out of both host memory (32 GB) and GPU memory (11 GB), and the process gets killed. The messages shown in the terminal were attached as a screenshot.

I also noticed that the program gets stuck at sess.run(opt_op). What is wrong, and how can I fix it?

Maybe
  • I think your `path_ds.map` function is taking up way too much space in memory, leaving little room for `shuffle` to do its job. Personally, when it comes to training on image datasets, I prefer to use generators, which can help control how much RAM your preprocessing function uses. You might also want to run this on a smaller dataset just to see if you get the same error. – alif Jul 05 '19 at 03:55
  • I've tried a smaller dataset as you suggested, but now I get a NaN error at `tf.clip_by_global_norm`... I'm trying to figure out why this happens. Do you have any idea how to make `map` work without changing the dataset? – Maybe Jul 05 '19 at 04:55

1 Answer


The problem is this:

ds = ds.shuffle(buffer_size = len(all_image_paths))

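The buffer that Dataset.shuffle() uses is an in-memory buffer. Because the shuffle comes after map(), the buffer holds decoded, resized image tensors, so with buffer_size set to the number of images you are effectively trying to load the whole dataset into memory.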
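To get a feel for the scale, here is a rough back-of-the-envelope estimate. The image count (82,783 for the COCO 2014 training split) and the 4 bytes per value (tf.image.resize returning float32) are assumptions on my part, not numbers taken from the question, and the estimate ignores any per-element overhead.

```python
# Rough estimate of the memory needed by a shuffle buffer that holds every
# decoded image. All figures below are assumptions for illustration.
NUM_IMAGES = 82_783                  # assumed size of the COCO 2014 training split
HEIGHT, WIDTH, CHANNELS = 256, 256, 3  # image_size from the question, RGB
BYTES_PER_VALUE = 4                  # float32 output of tf.image.resize

bytes_per_image = HEIGHT * WIDTH * CHANNELS * BYTES_PER_VALUE
full_buffer_bytes = NUM_IMAGES * bytes_per_image

print(f"per image: {bytes_per_image / 1e6:.2f} MB")
print(f"full shuffle buffer: {full_buffer_bytes / 1e9:.1f} GB")
```

Under those assumptions the buffer alone would need roughly twice the 32 GB of RAM mentioned in the question, which is consistent with the process being killed.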

You have a couple of options (which you can combine) to fix this:

Option 1:

Reduce the buffer size to a much smaller number.

Option 2:

Move the shuffle() statement before the map() statement.

This means we would be shuffling before we load the images, so the shuffle buffer would hold only the filenames rather than huge image tensors.
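A minimal sketch of how the function from the question might look with both options applied. This is an illustration rather than a tested drop-in fix: the `shuffle_buffer` parameter is a hypothetical addition of mine, and I use `tf.io.read_file` on the assumption that your TensorFlow version provides it (the question's `tf.read_file` is the older TF 1.x name).

```python
from pathlib import Path

import tensorflow as tf


def image_dataset(filepath, image_size, batch_size, norm=True, shuffle_buffer=None):
    """Shuffle file paths first, then decode, so the buffer never holds images."""

    def load_and_preprocess_image(path):
        image = tf.io.read_file(path)  # tf.read_file in older TF 1.x releases
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        if norm:
            image /= 255.0  # normalize to [0,1] range
        return image

    all_image_paths = [str(f) for f in Path(filepath).glob('*')]
    if shuffle_buffer is None:
        # Hypothetical default: a full shuffle is cheap here because the
        # buffer only stores path strings (Option 2). Pass a smaller value
        # to apply Option 1 as well.
        shuffle_buffer = len(all_image_paths)

    ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    ds = ds.shuffle(buffer_size=shuffle_buffer)  # shuffles paths, not tensors
    ds = ds.repeat()
    ds = ds.map(load_and_preprocess_image,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds
```

With this ordering, only the batches being prepared and prefetched hold decoded images at any one time, so memory use should no longer grow with the size of the dataset.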

Stewart_R