
I am trying to create a custom data generator and don't know how to integrate the yield statement, combined with an infinite loop, inside the __getitem__ method.

EDIT: After the answer I realized that the code I am using is a Sequence which doesn't need a yield statement.

Currently I am returning multiple images with a return statement:

import cv2
import numpy as np
import tensorflow

class DataGenerator(tensorflow.keras.utils.Sequence):
    def __init__(self, files, labels, batch_size=32, shuffle=True, random_state=42):
        'Initialization'
        self.files = files
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.files) / self.batch_size))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        files_batch = [self.files[k] for k in indexes]
        y = [self.labels[k] for k in indexes]

        # Generate data
        x = self.__data_generation(files_batch)

        return x, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.files))
        if self.shuffle:
            np.random.seed(self.random_state)
            np.random.shuffle(self.indexes)


    def __data_generation(self, files):
        imgs = []

        for img_file in files:

            img = cv2.imread(img_file, -1)

            ###############
            # Augment image
            ###############

            imgs.append(img) 

        return imgs

In this article I saw that yield is used in an infinite loop. I don't quite understand that syntax. How is the loop escaped?

oezguensi
  • Possible duplicate of [What does the "yield" keyword do?](https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do) – Triggernometry May 10 '19 at 14:05

2 Answers


You are using the Sequence API, which works a bit differently from plain generators. In a generator function, you use the yield keyword inside a while True: loop, so each time Keras requests the next item from the generator, it gets one batch of data, and the generator automatically wraps around at the end of the data.

But in a Sequence, the __getitem__ method receives an index parameter, so no iteration or yield is required; Keras performs the iteration for you. This design lets the sequence be read in parallel using multiprocessing, which plain generator functions cannot do safely.

So you are doing things the right way; no change is needed.
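
To make the difference concrete, here is a minimal sketch of how an index-based sequence is consumed. `ToySequence` and `consume_one_epoch` are hypothetical stand-ins written in plain Python, not the Keras implementation; the only assumption is that the caller asks for batches by index, `len(seq)` times per epoch, and then calls `on_epoch_end()`.

```python
class ToySequence:
    """Hypothetical stand-in for a keras.utils.Sequence (no Keras required)."""

    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size

    def __len__(self):
        # Number of full batches per epoch, as in the question's code.
        return len(self.samples) // self.batch_size

    def __getitem__(self, index):
        # Plain return: the caller chooses the index, so no yield is needed.
        start = index * self.batch_size
        return self.samples[start:start + self.batch_size]

    def on_epoch_end(self):
        pass  # a real Sequence could reshuffle its indexes here


def consume_one_epoch(seq):
    """Roughly what a training loop does: request each batch by index."""
    batches = [seq[i] for i in range(len(seq))]
    seq.on_epoch_end()
    return batches


seq = ToySequence(list(range(10)), batch_size=4)
print(consume_one_epoch(seq))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Because each batch depends only on its index, separate workers can fetch different indexes independently, which is what makes this style easy to parallelize.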

Dr. Snoopy
  • That was exactly what I couldn't understand. I used the example on https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly and saw the benefits. But I wasn't too sure what the code is doing exactly – oezguensi May 10 '19 at 14:30

Example of a generator for Keras:

def datagenerator(images, labels, batchsize, mode="train"):
    while True:
        start = 0
        end = batchsize

        while start < len(images):
            # load your images from numpy arrays or read from directory
            x = images[start:end]
            y = labels[start:end]
            yield x, y

            start += batchsize
            end += batchsize

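Keras expects the generator to loop forever. The loop is never escaped from inside: each yield hands one batch to Keras and pauses the function, and Keras simply stops requesting batches once it has drawn the number of steps it needs for the epoch.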

If you want to learn about Python generators, then the link in the comments is actually a good place to start.
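
As a sketch of how the infinite loop is "escaped": the generator never exits on its own. It pauses at each yield, and the consumer just stops asking for batches. The snippet below follows the same pattern as the generator above; `batch_generator` is a hypothetical name, and `itertools.islice` stands in for a consumer that draws a fixed number of batches. It is an illustration in plain Python, not what Keras does internally.

```python
from itertools import islice


def batch_generator(images, labels, batchsize):
    """Same pattern as the generator above: loop over the data forever."""
    while True:
        for start in range(0, len(images), batchsize):
            yield images[start:start + batchsize], labels[start:start + batchsize]


images = list(range(6))
labels = [i % 2 for i in images]
gen = batch_generator(images, labels, batchsize=2)

# Nothing runs until a batch is requested; each next() resumes at the yield.
first = next(gen)
print(first)  # ([0, 1], [0, 1])

# Draw four more batches and stop: the generator stays paused, it never "exits".
for x, y in islice(gen, 4):
    print(x, y)
```

After the loop, `gen` is still alive and would continue from where it paused if another batch were requested.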

Anakin
  • But how can I include this in my class? I am using `__getitem__`. – oezguensi May 10 '19 at 14:22
  • You can pass the generator to `model.fit_generator()`. – Anakin May 10 '19 at 14:23
  • You do not need `__getitem__`. – Anakin May 10 '19 at 14:24
  • @Anakin That is not true, a Sequence needs a __getitem__ – Dr. Snoopy May 10 '19 at 14:25
  • I meant that you do not need the Sequence, you can also do with a data generator. The question even mentions a Keras data generator. – Anakin May 10 '19 at 17:55
  • @Anakin Note that `keras` no longer requires the `.fit_generator` method. You can pass the generator mentioned above directly to `.fit`. See [the documentation](https://keras.io/api/models/model_training_apis/#fit-method) for more details. – FraSchelle Jun 18 '20 at 14:13
  • Note: In the call to `.fit`, you must specify `steps_per_epoch`; otherwise the epoch never ends, because the generator loops forever. A good default is `number_of_samples // batch_size`. – Contango Feb 21 '23 at 11:51
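
A small sketch of the arithmetic behind that last comment, using only the standard library. The helper names are hypothetical. One pass of the generator in the answer yields `ceil(n / batch_size)` batches, because its final slice may be shorter than the rest, whereas `n // batch_size` counts only full batches. When the two differ, epoch boundaries will not line up exactly with passes over the data.

```python
import math


def full_batches(number_of_samples, batch_size):
    # The default suggested in the comment above: ignores any partial batch.
    return number_of_samples // batch_size


def batches_per_pass(number_of_samples, batch_size):
    # Batches produced by one pass of a generator that also yields
    # the final, shorter slice.
    return math.ceil(number_of_samples / batch_size)


print(full_batches(100, 32))      # 3
print(batches_per_pass(100, 32))  # 4
print(full_batches(96, 32), batches_per_pass(96, 32))  # 3 3
```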