
I'm using Keras to train up a neural network, and I've reached the point where my data sets are getting bigger than the amount of RAM installed on my computer, so it's time to modify my training script to call model.fit_generator() instead of model.fit(), so that I don't have to load all of the training and validation data into RAM at once.

I've made the modification, and AFAICT it is working fine, but there's one thing that is bothering me a little bit -- all of the example usages of fit_generator() I've seen online use Python's yield feature to store the generator's state. I'm an old C++ programmer at heart, and suspicious of features like yield that I don't entirely understand, so I wanted to maintain my generator's state explicitly rather than implicitly. I therefore implemented my generator like this:

import math
import struct
import sys

import numpy as np

class DataGenerator:
   def __init__(self, inputFileName, maxExamplesPerBatch):
      self._inputFileName       = inputFileName
      self._maxExamplesPerBatch = maxExamplesPerBatch

      # Note: open() raises OSError on failure rather than returning None,
      # so we use try/except here instead of comparing the result to None
      try:
         self._inputsFile  = open(inputFileName, "rb")
         self._outputsFile = open(inputFileName, "rb")  # yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
      except OSError:
         print("Couldn't open file %s to read data" % inputFileName)
         sys.exit(10)

      headerInfo = struct.unpack("<4L", self._inputsFile.read(16))  # 16-byte header: magic, numExamples, numInputs, numOutputs
      if (headerInfo[0] != 1414676815):
         print("Bad magic number in input file [%s], aborting!" % inputFileName)
         sys.exit(10)

      self._numExamples   = headerInfo[1]  # Number of input->output rows in our data-file (typically quite large, i.e. millions)
      self._numInputs     = headerInfo[2]  # Number of input values in each row
      self._numOutputs    = headerInfo[3]  # Number of output values in each row
      self.seekToTopOfData()

   def __len__(self):
      return (math.ceil(self._numExamples/self._maxExamplesPerBatch))

   def __next__(self):
      numExamplesToLoad = self._maxExamplesPerBatch
      numExamplesLeft   = self._numExamples - self._curExampleIdx
      if (numExamplesLeft < numExamplesToLoad):
         numExamplesToLoad = numExamplesLeft
      inputData  = np.reshape(np.fromfile(self._inputsFile,  dtype='<f4', count=(numExamplesToLoad*self._numInputs)),  (numExamplesToLoad, self._numInputs))
      outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=(numExamplesToLoad*self._numOutputs)), (numExamplesToLoad, self._numOutputs))
      self._curExampleIdx += numExamplesToLoad
      if (self._curExampleIdx == self._numExamples):
         self.seekToTopOfData()
      return (inputData, outputData)   # <----- NOTE return, not yield!!

   def seekToTopOfData(self):
      self._curExampleIdx = 0
      self._inputsFile.seek(16)
      self._outputsFile.seek(16+(self._numExamples*self._numInputs*4))

[...]

trainingDataGenerator   = DataGenerator(trainingInputFileName, maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)

model.fit_generator(generator=trainingDataGenerator,
                    steps_per_epoch=len(trainingDataGenerator),
                    epochs=maxEpochs,
                    callbacks=callbacks_list,
                    validation_data=validationDataGenerator,
                    validation_steps=len(validationDataGenerator))

... note that my __next__(self) method ends with a return rather than a yield, and that I'm storing the generator's state explicitly (in private member variables of the DataGenerator object) rather than implicitly (via yield magic). This seems to work fine.
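For comparison, here is a rough sketch of the yield-based style the online examples use, applied to the same file format (hypothetical code, not what I'm actually running; it assumes the same header layout as DataGenerator above, and omits the magic-number check for brevity):

def dataGeneratorFunc(inputFileName, maxExamplesPerBatch):
   # Hypothetical yield-based equivalent of the DataGenerator class above.
   # State (file positions, curExampleIdx) is preserved implicitly across yields.
   inputsFile  = open(inputFileName, "rb")
   outputsFile = open(inputFileName, "rb")
   (magic, numExamples, numInputs, numOutputs) = struct.unpack("<4L", inputsFile.read(16))
   while True:   # fit_generator() expects its generator to yield batches forever
      inputsFile.seek(16)
      outputsFile.seek(16 + (numExamples*numInputs*4))
      curExampleIdx = 0
      while curExampleIdx < numExamples:
         n = min(maxExamplesPerBatch, numExamples - curExampleIdx)
         inputData  = np.reshape(np.fromfile(inputsFile,  dtype='<f4', count=n*numInputs),  (n, numInputs))
         outputData = np.reshape(np.fromfile(outputsFile, dtype='<f4', count=n*numOutputs), (n, numOutputs))
         curExampleIdx += n
         yield (inputData, outputData)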

My question is, will this unusual approach introduce any non-obvious behavioral problems that I should be aware of?

Jeremy Friesner

1 Answer


On a superficial examination, your code checks out. When you write a generator function and call it, the call returns a generator object whose __next__ method is typically called repeatedly by iteration until it raises the StopIteration exception.
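To make that concrete, here's a small illustration (the countdown function is just an example, unrelated to your code):

def countdown(n):
   # a generator function: calling it returns a generator object
   while n > 0:
      yield n      # execution suspends here; local state is kept implicitly
      n -= 1

g = countdown(2)
print(next(g))      # 2  (next() invokes g.__next__())
print(next(g))      # 1
print(next(g))      # raises StopIteration -- this is how a for loop knows to stop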

A generator is a special case of an iterator. Iterables, like lists, have an __iter__ method that produces an iterator.

Unless you want to send values into your generator as well as get values out of it, your DataGenerator is a reasonable way to implement an iterator, but to write an iterable you'd need another class whose __iter__ method returns an instance of DataGenerator, as sketched below.
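A minimal sketch of such a wrapper (the name DataIterable is made up for illustration):

class DataIterable:
   def __init__(self, inputFileName, maxExamplesPerBatch):
      self._inputFileName       = inputFileName
      self._maxExamplesPerBatch = maxExamplesPerBatch

   def __iter__(self):
      # each call to iter() hands back a fresh iterator over the data
      return DataGenerator(self._inputFileName, self._maxExamplesPerBatch)

That would let constructs like for batch in DataIterable(fileName, 1000): work, since a for loop calls iter() on its target first. Note, though, that because your __next__ rewinds to the top of the data instead of raising StopIteration, such a loop would never terminate on its own.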

The answers at How to implement __iter__(self) for a container object (Python) might also be helpful.

holdenweb