I'm attempting to modify the keras fit_generator() function to do gradient accumulation for an MRCNN model as I don't have a lot of GPU memory. I'll post the original code along with my modified version below.
Original:
def _make_train_function(self):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
def train_on_batch(self, x, y,
sample_weight=None,
class_weight=None):
x, y, sample_weights = self._standardize_user_data(
x, y,
sample_weight=sample_weight,
class_weight=class_weight,
check_batch_axis=True)
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + y + sample_weights + [1.]
else:
ins = x + y + sample_weights
self._make_train_function()
outputs = self.train_function(ins)
if len(outputs) == 1:
return outputs[0]
else:
return outputs
def fit_generator(self, generator,
steps_per_epoch,
epochs=1,
mini_batches=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0):
wait_time = 0.01 # in seconds
epoch = initial_epoch
do_validation = bool(validation_data)
self._make_train_function()
if do_validation:
self._make_test_function()
# python 2 has 'next', 3 has '__next__'
# avoid any explicit version checks
val_gen = (hasattr(validation_data, 'next') or
hasattr(validation_data, '__next__') or
isinstance(validation_data, Sequence))
if val_gen and not validation_steps:
raise ValueError('When using a generator for validation data, '
'you must specify a value for '
'`validation_steps`.')
# Prepare display labels.
out_labels = self._get_deduped_metrics_names()
callback_metrics = out_labels + ['val_' + n for n in out_labels]
# prepare callbacks
self.history = cbks.History()
callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
if verbose:
callbacks += [cbks.ProgbarLogger(count_mode='steps')]
callbacks = cbks.CallbackList(callbacks)
# it's possible to callback a different model than self:
if hasattr(self, 'callback_model') and self.callback_model:
callback_model = self.callback_model
else:
callback_model = self
callbacks.set_model(callback_model)
callbacks.set_params({
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
})
callbacks.on_train_begin()
if do_validation and not val_gen:
if len(validation_data) == 2:
val_x, val_y = validation_data
val_sample_weight = None
elif len(validation_data) == 3:
val_x, val_y, val_sample_weight = validation_data
else:
raise ValueError('`validation_data` should be a tuple '
'`(val_x, val_y, val_sample_weight)` '
'or `(val_x, val_y)`. Found: ' +
str(validation_data))
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x, val_y, val_sample_weight)
val_data = val_x + val_y + val_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
val_data += [0.]
for cbk in callbacks:
cbk.validation_data = val_data
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
enqueuer = None
try:
if is_sequence:
enqueuer = OrderedEnqueuer(generator,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle)
else:
enqueuer = GeneratorEnqueuer(generator,
use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
callback_model.stop_training = False
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
steps_done = 0
batch_index = 0
while steps_done < steps_per_epoch:
generator_output = next(output_generator)
if not hasattr(generator_output, '__len__'):
raise ValueError('Output of generator should be '
'a tuple `(x, y, sample_weight)` '
'or `(x, y)`. Found: ' +
str(generator_output))
if len(generator_output) == 2:
x, y = generator_output
sample_weight = None
elif len(generator_output) == 3:
x, y, sample_weight = generator_output
else:
raise ValueError('Output of generator should be '
'a tuple `(x, y, sample_weight)` '
'or `(x, y)`. Found: ' +
str(generator_output))
# build batch logs
batch_logs = {}
if isinstance(x, list):
batch_size = x[0].shape[0]
elif isinstance(x, dict):
batch_size = list(x.values())[0].shape[0]
else:
batch_size = x.shape[0]
batch_logs['batch'] = batch_index
batch_logs['size'] = batch_size
callbacks.on_batch_begin(batch_index, batch_logs)
outs = self.train_on_batch(x, y,
sample_weight=sample_weight,
class_weight=class_weight)
if not isinstance(outs, list):
outs = [outs]
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(batch_index, batch_logs)
# Construct epoch logs.
epoch_logs = {}
batch_index += 1
steps_done += 1
# Epoch finished.
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(
validation_data,
validation_steps,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing)
else:
# No need for try/except because
# data has already been validated.
val_outs = self.evaluate(
val_x, val_y,
batch_size=batch_size,
sample_weight=val_sample_weights,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
if callback_model.stop_training:
break
callbacks.on_epoch_end(epoch, epoch_logs)
epoch += 1
if callback_model.stop_training:
break
finally:
if enqueuer is not None:
enqueuer.stop()
callbacks.on_train_end()
return self.history
To modify it I essentially added a mini-batch parameter that is used by the _make_train_function such that it will not update the weights until that batch is reached.
The _make_train_function also now returns the updates value such that this can be maintained and continue to accumulate until the mini_batch value is reached, which is monitored by the current_batch value.
In the fit_generator() function I have it loop through the entire mini_batch for every epoch step. The idea is that until the self.train_function = K.function()
line happens the weights do not update, so I can just keep track of the updates for every image until the end of the mini_batch is reached then apply the update. I do not actually know if that is correct and my approach seems like it could be too simplistic to work, like I'm missing an important detail.
And modified version:
def _make_train_function(self , updates, current_batch, mini_batches):
if not hasattr(self, 'train_function'):
raise RuntimeError('You must compile your model before using it.')
if self.train_function is None:
inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs += [K.learning_phase()]
with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
training_updates = self.optimizer.get_updates(
params=self._collected_trainable_weights,
loss=self.total_loss)
updates = updates + self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
if current_batch == mini_batches:
self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)
current_batch +=1
return updates
def train_on_batch(self, x, y,
sample_weight=None,
class_weight=None, updates = [], current_batch = 1, mini_batches = 1):
x, y, sample_weights = self._standardize_user_data(
x, y,
sample_weight=sample_weight,
class_weight=class_weight,
check_batch_axis=True)
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + y + sample_weights + [1.]
else:
ins = x + y + sample_weights
new_update = self._make_train_function(updates,current_batch, mini_batches)
if current_batch == mini_batches:
outputs = self.train_function(ins)
if len(outputs) == 1:
return outputs[0]
return outputs
else:
return new_update
def fit_generator(self, generator,
steps_per_epoch,
epochs=1,
mini_batches=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0):
wait_time = 0.01 # in seconds
epoch = initial_epoch
do_validation = bool(validation_data)
self._make_train_function([], 1, 1)
if do_validation:
self._make_test_function()
# python 2 has 'next', 3 has '__next__'
# avoid any explicit version checks
val_gen = (hasattr(validation_data, 'next') or
hasattr(validation_data, '__next__') or
isinstance(validation_data, Sequence))
if val_gen and not validation_steps:
raise ValueError('When using a generator for validation data, '
'you must specify a value for '
'`validation_steps`.')
# Prepare display labels.
out_labels = self._get_deduped_metrics_names()
callback_metrics = out_labels + ['val_' + n for n in out_labels]
# prepare callbacks
self.history = cbks.History()
callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
if verbose:
callbacks += [cbks.ProgbarLogger(count_mode='steps')]
callbacks = cbks.CallbackList(callbacks)
# it's possible to callback a different model than self:
if hasattr(self, 'callback_model') and self.callback_model:
callback_model = self.callback_model
else:
callback_model = self
callbacks.set_model(callback_model)
callbacks.set_params({
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
})
callbacks.on_train_begin()
if do_validation and not val_gen:
if len(validation_data) == 2:
val_x, val_y = validation_data
val_sample_weight = None
elif len(validation_data) == 3:
val_x, val_y, val_sample_weight = validation_data
else:
raise ValueError('`validation_data` should be a tuple '
'`(val_x, val_y, val_sample_weight)` '
'or `(val_x, val_y)`. Found: ' +
str(validation_data))
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x, val_y, val_sample_weight)
val_data = val_x + val_y + val_sample_weights
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
val_data += [0.]
for cbk in callbacks:
cbk.validation_data = val_data
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
enqueuer = None
try:
if is_sequence:
enqueuer = OrderedEnqueuer(generator,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle)
else:
enqueuer = GeneratorEnqueuer(generator,
use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
callback_model.stop_training = False
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
steps_done = 0
batch_index = 0
while steps_done < steps_per_epoch:
current_batch = 1
updates = []
while current_batch <= mini_batches:
generator_output = next(output_generator)
if not hasattr(generator_output, '__len__'):
raise ValueError('Output of generator should be '
'a tuple `(x, y, sample_weight)` '
'or `(x, y)`. Found: ' +
str(generator_output))
if len(generator_output) == 2:
x, y = generator_output
sample_weight = None
elif len(generator_output) == 3:
x, y, sample_weight = generator_output
else:
raise ValueError('Output of generator should be '
'a tuple `(x, y, sample_weight)` '
'or `(x, y)`. Found: ' +
str(generator_output))
# build batch logs
batch_logs = {}
if isinstance(x, list):
batch_size = x[0].shape[0]
elif isinstance(x, dict):
batch_size = list(x.values())[0].shape[0]
else:
batch_size = x.shape[0]
batch_logs['batch'] = batch_index
batch_logs['size'] = batch_size
callbacks.on_batch_begin(batch_index, batch_logs)
if current_batch == 1:
new_update = self.train_on_batch(x, y,
sample_weight=sample_weight,
class_weight=class_weight, updates = updates, current_batch = current_batch, mini_batches = mini_batches)
if current_batch >1 and current_batch < mini_batches:
new_update = self.train_on_batch(x, y,
sample_weight=sample_weight,
class_weight=class_weight, updates = new_update, current_batch = current_batch, mini_batches = mini_batches)
if current_batch == mini_batches:
outs = self.train_on_batch(x, y,
sample_weight=sample_weight,
class_weight=class_weight, updates = new_update, current_batch = current_batch, mini_batches = mini_batches)
current_batch +=1
if not isinstance(outs, list):
outs = [outs]
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(batch_index, batch_logs)
# Construct epoch logs.
epoch_logs = {}
batch_index += 1
steps_done += 1
# Epoch finished.
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(
validation_data,
validation_steps,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing)
else:
# No need for try/except because
# data has already been validated.
val_outs = self.evaluate(
val_x, val_y,
batch_size=batch_size,
sample_weight=val_sample_weights,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
if callback_model.stop_training:
break
callbacks.on_epoch_end(epoch, epoch_logs)
epoch += 1
if callback_model.stop_training:
break
finally:
if enqueuer is not None:
enqueuer.stop()
callbacks.on_train_end()
return self.history
I'm currently training and through the first epoch everything seemed like it was working, the loss had a healthy downward trend. Then, when the second epoch started the loss value on the first step went down very low to about 0.9 (whereas it had ended the first epoch around 4), then it jumped back up on the second step to 2.9, and continued a downward trend reaching 1.68 until step 91 (out of 143 steps per epoch) when it jumped back up to 2.9. I just have a feeling I'm missing something in my approach and would appreciate any assistance.