I'm using Keras to train a neural network, and I've reached the point where my data sets are getting bigger than the amount of RAM installed on my computer. That means it's time to modify my training script to call model.fit_generator() instead of model.fit(), so that I don't have to load all of the training and validation data into RAM at once.
I've made the modification, and AFAICT it is working fine, but there's one thing that is bothering me a little bit: all of the example usages of fit_generator() I've seen online use Python's yield feature to store the generator's state. I'm an old C++ programmer at heart and suspicious of features like yield that I don't entirely understand, so I wanted to maintain my generator's state explicitly rather than implicitly. I therefore implemented my generator like this:
import math
import struct
import sys

import numpy as np

class DataGenerator:
    def __init__(self, inputFileName, maxExamplesPerBatch):
        self._inputFileName       = inputFileName
        self._maxExamplesPerBatch = maxExamplesPerBatch
        try:
            self._inputsFile = open(inputFileName, "rb")
        except IOError:
            print("Couldn't open file %s to read input data" % inputFileName)
            sys.exit(10)
        try:
            # Yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
            self._outputsFile = open(inputFileName, "rb")
        except IOError:
            print("Couldn't open file %s to read output data" % inputFileName)
            sys.exit(10)
        headerInfo = struct.unpack("<4L", self._inputsFile.read(16))
        if headerInfo[0] != 1414676815:
            print("Bad magic number in input file [%s], aborting!" % inputFileName)
            sys.exit(10)
        self._numExamples = headerInfo[1]  # number of input->output rows in our data file (typically quite large, i.e. millions)
        self._numInputs   = headerInfo[2]  # number of input values in each row
        self._numOutputs  = headerInfo[3]  # number of output values in each row
        self.seekToTopOfData()

    def __len__(self):
        # number of batches per epoch
        return math.ceil(self._numExamples / self._maxExamplesPerBatch)

    def __next__(self):
        # the final batch of an epoch may be smaller than the others
        numExamplesToLoad = self._maxExamplesPerBatch
        numExamplesLeft   = self._numExamples - self._curExampleIdx
        if numExamplesLeft < numExamplesToLoad:
            numExamplesToLoad = numExamplesLeft
        inputData  = np.reshape(np.fromfile(self._inputsFile,  dtype='<f4', count=numExamplesToLoad*self._numInputs),  (numExamplesToLoad, self._numInputs))
        outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=numExamplesToLoad*self._numOutputs), (numExamplesToLoad, self._numOutputs))
        self._curExampleIdx += numExamplesToLoad
        if self._curExampleIdx == self._numExamples:
            self.seekToTopOfData()  # wrap around so the next epoch starts back at the top
        return (inputData, outputData)  # <----- NOTE return, not yield!!

    def seekToTopOfData(self):
        self._curExampleIdx = 0
        self._inputsFile.seek(16)                                           # skip past the 16-byte header
        self._outputsFile.seek(16 + (self._numExamples*self._numInputs*4))  # output values are stored after all the input values
[...]
trainingDataGenerator   = DataGenerator(trainingInputFileName,   maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)
model.fit_generator(generator=trainingDataGenerator,
                    steps_per_epoch=len(trainingDataGenerator),
                    epochs=maxEpochs,
                    callbacks=callbacks_list,
                    validation_data=validationDataGenerator,
                    validation_steps=len(validationDataGenerator))
... note that my __next__(self) method ends with a return rather than a yield, and that I'm storing the generator's state explicitly (via private member variables in the DataGenerator object) rather than implicitly (via yield magic). This seems to work fine.
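For reference, here's what I understand the equivalent yield-based generator would look like (a minimal sketch, assuming the same binary file layout as my class above, and omitting the magic-number check for brevity; dataGeneratorFunc is just an illustrative name):

import math
import struct

import numpy as np

def dataGeneratorFunc(inputFileName, maxExamplesPerBatch):
    with open(inputFileName, "rb") as inputsFile, open(inputFileName, "rb") as outputsFile:
        (magic, numExamples, numInputs, numOutputs) = struct.unpack("<4L", inputsFile.read(16))
        numBatches = math.ceil(numExamples / maxExamplesPerBatch)
        while True:  # fit_generator() expects its generator to loop forever
            inputsFile.seek(16)                             # top of the input values
            outputsFile.seek(16 + numExamples*numInputs*4)  # top of the output values
            for i in range(numBatches):
                n = min(maxExamplesPerBatch, numExamples - i*maxExamplesPerBatch)
                inputData  = np.fromfile(inputsFile,  dtype='<f4', count=n*numInputs ).reshape(n, numInputs)
                outputData = np.fromfile(outputsFile, dtype='<f4', count=n*numOutputs).reshape(n, numOutputs)
                yield (inputData, outputData)  # execution suspends here; all local state is kept implicitly

The local variables of the suspended function frame play exactly the role that my private member variables play in the class version; that implicit state is the part I wanted to make explicit.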
My question is, will this unusual approach introduce any non-obvious behavioral problems that I should be aware of?