
I'm attempting to modify the Keras fit_generator() function to do gradient accumulation for an MRCNN model, since I don't have a lot of GPU memory. I'll post the original code along with my modified version below.

Original:

 def _make_train_function(self):

    if not hasattr(self, 'train_function'):
        raise RuntimeError('You must compile your model before using it.')
    if self.train_function is None:
        inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            inputs += [K.learning_phase()]

        with K.name_scope('training'):
            with K.name_scope(self.optimizer.__class__.__name__):
                training_updates = self.optimizer.get_updates(
                    params=self._collected_trainable_weights,
                    loss=self.total_loss)
                updates = self.updates + training_updates
                
            # Gets loss and metrics. Updates weights at each call.

            
            self.train_function = K.function(inputs,
                                 [self.total_loss] + self.metrics_tensors,
                                  updates=updates,
                                  name='train_function',
                                  **self._function_kwargs)
    
 def train_on_batch(self, x, y,
                    sample_weight=None,
                    class_weight=None):


    x, y, sample_weights = self._standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        check_batch_axis=True)
    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = x + y + sample_weights + [1.]
    else:
        ins = x + y + sample_weights

    self._make_train_function()

    
    outputs = self.train_function(ins)
    if len(outputs) == 1:
        return outputs[0] 
    else:
        return outputs

 def fit_generator(self, generator,
                   steps_per_epoch,
                   epochs=1,
                   verbose=1,
                   callbacks=None,
                   validation_data=None,
                   validation_steps=None,
                   class_weight=None,
                   max_queue_size=10,
                   workers=1,
                   use_multiprocessing=False,
                   shuffle=True,
                   initial_epoch=0):

    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    self._make_train_function()
    if do_validation:
        self._make_test_function()

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next') or
               hasattr(validation_data, '__next__') or
               isinstance(validation_data, Sequence))
    if val_gen and not validation_steps:
        raise ValueError('When using a generator for validation data, '
                         'you must specify a value for '
                         '`validation_steps`.')

    # Prepare display labels.
    out_labels = self._get_deduped_metrics_names()
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    self.history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
    if verbose:
        callbacks += [cbks.ProgbarLogger(count_mode='steps')]
    callbacks = cbks.CallbackList(callbacks)

    # it's possible to callback a different model than self:
    if hasattr(self, 'callback_model') and self.callback_model:
        callback_model = self.callback_model
    else:
        callback_model = self
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    if do_validation and not val_gen:
        if len(validation_data) == 2:
            val_x, val_y = validation_data
            val_sample_weight = None
        elif len(validation_data) == 3:
            val_x, val_y, val_sample_weight = validation_data
        else:
            raise ValueError('`validation_data` should be a tuple '
                             '`(val_x, val_y, val_sample_weight)` '
                             'or `(val_x, val_y)`. Found: ' +
                             str(validation_data))
        val_x, val_y, val_sample_weights = self._standardize_user_data(
            val_x, val_y, val_sample_weight)
        val_data = val_x + val_y + val_sample_weights
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            val_data += [0.]
        for cbk in callbacks:
            cbk.validation_data = val_data
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    enqueuer = None

    try:
        if is_sequence:
            enqueuer = OrderedEnqueuer(generator,
                                       use_multiprocessing=use_multiprocessing,
                                       shuffle=shuffle)
        else:
            enqueuer = GeneratorEnqueuer(generator,
                                         use_multiprocessing=use_multiprocessing,
                                         wait_time=wait_time)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        callback_model.stop_training = False
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            steps_done = 0 
            batch_index = 0
            while steps_done < steps_per_epoch:
                generator_output = next(output_generator)

                if not hasattr(generator_output, '__len__'):
                    raise ValueError('Output of generator should be '
                                     'a tuple `(x, y, sample_weight)` '
                                     'or `(x, y)`. Found: ' +
                                     str(generator_output))
                if len(generator_output) == 2:
                    x, y = generator_output
                    sample_weight = None
                elif len(generator_output) == 3:
                    x, y, sample_weight = generator_output
                else:
                    raise ValueError('Output of generator should be '
                                         'a tuple `(x, y, sample_weight)` '
                                         'or `(x, y)`. Found: ' +
                                         str(generator_output))
                # build batch logs
                batch_logs = {}
                if isinstance(x, list):
                    batch_size = x[0].shape[0]
                elif isinstance(x, dict):
                    batch_size = list(x.values())[0].shape[0]
                else:
                    batch_size = x.shape[0]
                batch_logs['batch'] = batch_index
                batch_logs['size'] = batch_size
                callbacks.on_batch_begin(batch_index, batch_logs)

        
                outs = self.train_on_batch(x, y,
                                   sample_weight=sample_weight,
                                   class_weight=class_weight)
      

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                # Construct epoch logs.
                epoch_logs = {}
                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = self.evaluate_generator(
                            validation_data,
                            validation_steps,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = self.evaluate(
                            val_x, val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    callbacks.on_train_end()
    return self.history

To modify it, I essentially added a mini_batches parameter that is used by _make_train_function so that it does not update the weights until that many batches have been processed.

_make_train_function also now returns the updates value so that it can be carried over and continue to accumulate until the mini_batches count is reached, which is tracked by the current_batch value.

In fit_generator() I loop through the entire mini-batch for every epoch step. The idea is that until the self.train_function = K.function() line runs, the weights are not updated, so I can just keep collecting the updates for every image until the end of the mini-batch is reached and then apply them all at once. I don't actually know whether that is correct, and my approach seems like it could be too simplistic to work, as if I'm missing an important detail.
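In pseudocode, the per-step flow I'm aiming for looks roughly like the sketch below (run_one_step and sub_batches are hypothetical placeholder names just to illustrate the control flow; they are not part of Keras or of my actual code):

    # Rough sketch of the intended control flow only -- run_one_step and
    # sub_batches are hypothetical placeholders, not Keras API.
    def run_one_step(sub_batches, mini_batches, train_on_batch):
        """sub_batches: an iterable of (x, y) pairs, one per slice of the mini-batch."""
        updates = []
        outs = None
        for current_batch, (x, y) in enumerate(sub_batches, start=1):
            if current_batch < mini_batches:
                # Keep accumulating the symbolic updates; the weights are
                # (intended to be) left untouched on these calls.
                updates = train_on_batch(x, y, updates=updates,
                                         current_batch=current_batch,
                                         mini_batches=mini_batches)
            else:
                # Last slice of the mini-batch: build the K.function with all
                # accumulated updates and run it once, applying them together.
                outs = train_on_batch(x, y, updates=updates,
                                      current_batch=current_batch,
                                      mini_batches=mini_batches)
        return outs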

And the modified version:

 def _make_train_function(self, updates, current_batch, mini_batches):

    if not hasattr(self, 'train_function'):
        raise RuntimeError('You must compile your model before using it.')
    if self.train_function is None:
        inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            inputs += [K.learning_phase()]

        with K.name_scope('training'):
            with K.name_scope(self.optimizer.__class__.__name__):
                training_updates = self.optimizer.get_updates(
                    params=self._collected_trainable_weights,
                    loss=self.total_loss)
             
                updates = updates + self.updates + training_updates
                
            # Gets loss and metrics. Updates weights at each call.

            if current_batch == mini_batches:
                self.train_function = K.function(inputs,
                                             [self.total_loss] + self.metrics_tensors,
                                             updates=updates,
                                             name='train_function',
                                             **self._function_kwargs)
            current_batch += 1
                
    return updates
    
 def train_on_batch(self, x, y,
                    sample_weight=None,
                    class_weight=None,
                    updates=[], current_batch=1, mini_batches=1):

   
    x, y, sample_weights = self._standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        check_batch_axis=True)
    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
        ins = x + y + sample_weights + [1.]
    else:
        ins = x + y + sample_weights
    new_update = self._make_train_function(updates, current_batch, mini_batches)

    if current_batch == mini_batches:
        outputs = self.train_function(ins)
        if len(outputs) == 1:
            return outputs[0]
        return outputs
    else:
        return new_update

 def fit_generator(self, generator,
                   steps_per_epoch,
                   epochs=1,
                   mini_batches=1,
                   verbose=1,
                   callbacks=None,
                   validation_data=None,
                   validation_steps=None,
                   class_weight=None,
                   max_queue_size=10,
                   workers=1,
                   use_multiprocessing=False,
                   shuffle=True,
                   initial_epoch=0):
    
    wait_time = 0.01  # in seconds
    epoch = initial_epoch

    do_validation = bool(validation_data)
    self._make_train_function([], 1, 1)
    if do_validation:
        self._make_test_function()

    # python 2 has 'next', 3 has '__next__'
    # avoid any explicit version checks
    val_gen = (hasattr(validation_data, 'next') or
               hasattr(validation_data, '__next__') or
               isinstance(validation_data, Sequence))
    if val_gen and not validation_steps:
        raise ValueError('When using a generator for validation data, '
                         'you must specify a value for '
                         '`validation_steps`.')

    # Prepare display labels.
    out_labels = self._get_deduped_metrics_names()
    callback_metrics = out_labels + ['val_' + n for n in out_labels]

    # prepare callbacks
    self.history = cbks.History()
    callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
    if verbose:
        callbacks += [cbks.ProgbarLogger(count_mode='steps')]
    callbacks = cbks.CallbackList(callbacks)

    # it's possible to callback a different model than self:
    if hasattr(self, 'callback_model') and self.callback_model:
        callback_model = self.callback_model
    else:
        callback_model = self
    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': epochs,
        'steps': steps_per_epoch,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics,
    })
    callbacks.on_train_begin()

    if do_validation and not val_gen:
        if len(validation_data) == 2:
            val_x, val_y = validation_data
            val_sample_weight = None
        elif len(validation_data) == 3:
            val_x, val_y, val_sample_weight = validation_data
        else:
            raise ValueError('`validation_data` should be a tuple '
                             '`(val_x, val_y, val_sample_weight)` '
                             'or `(val_x, val_y)`. Found: ' +
                             str(validation_data))
        val_x, val_y, val_sample_weights = self._standardize_user_data(
            val_x, val_y, val_sample_weight)
        val_data = val_x + val_y + val_sample_weights
        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
            val_data += [0.]
        for cbk in callbacks:
            cbk.validation_data = val_data
    is_sequence = isinstance(generator, Sequence)
    if not is_sequence and use_multiprocessing and workers > 1:
        warnings.warn(
            UserWarning('Using a generator with `use_multiprocessing=True`'
                        ' and multiple workers may duplicate your data.'
                        ' Please consider using the`keras.utils.Sequence'
                        ' class.'))
    enqueuer = None

    try:
        if is_sequence:
            enqueuer = OrderedEnqueuer(generator,
                                       use_multiprocessing=use_multiprocessing,
                                       shuffle=shuffle)
        else:
            enqueuer = GeneratorEnqueuer(generator,
                                         use_multiprocessing=use_multiprocessing,
                                         wait_time=wait_time)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        callback_model.stop_training = False
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            steps_done = 0 
            batch_index = 0
            while steps_done < steps_per_epoch:
                current_batch = 1
                updates = []
                while current_batch <= mini_batches:
                    generator_output = next(output_generator)

                    if not hasattr(generator_output, '__len__'):
                        raise ValueError('Output of generator should be '
                                         'a tuple `(x, y, sample_weight)` '
                                         'or `(x, y)`. Found: ' +
                                         str(generator_output))
                    if len(generator_output) == 2:
                        x, y = generator_output
                        sample_weight = None
                    elif len(generator_output) == 3:
                        x, y, sample_weight = generator_output
                    else:
                        raise ValueError('Output of generator should be '
                                         'a tuple `(x, y, sample_weight)` '
                                         'or `(x, y)`. Found: ' +
                                         str(generator_output))
                    # build batch logs
                    batch_logs = {}
                    if isinstance(x, list):
                        batch_size = x[0].shape[0]
                    elif isinstance(x, dict):
                        batch_size = list(x.values())[0].shape[0]
                    else:
                        batch_size = x.shape[0]
                    batch_logs['batch'] = batch_index
                    batch_logs['size'] = batch_size
                    callbacks.on_batch_begin(batch_index, batch_logs)

                    if current_batch == 1:
                        new_update = self.train_on_batch(
                            x, y,
                            sample_weight=sample_weight,
                            class_weight=class_weight,
                            updates=updates,
                            current_batch=current_batch,
                            mini_batches=mini_batches)
                    if current_batch > 1 and current_batch < mini_batches:
                        new_update = self.train_on_batch(
                            x, y,
                            sample_weight=sample_weight,
                            class_weight=class_weight,
                            updates=new_update,
                            current_batch=current_batch,
                            mini_batches=mini_batches)
                    if current_batch == mini_batches:
                        outs = self.train_on_batch(
                            x, y,
                            sample_weight=sample_weight,
                            class_weight=class_weight,
                            updates=new_update,
                            current_batch=current_batch,
                            mini_batches=mini_batches)

                    current_batch += 1

                if not isinstance(outs, list):
                    outs = [outs]
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks.on_batch_end(batch_index, batch_logs)

                # Construct epoch logs.
                epoch_logs = {}
                batch_index += 1
                steps_done += 1

                # Epoch finished.
                if steps_done >= steps_per_epoch and do_validation:
                    if val_gen:
                        val_outs = self.evaluate_generator(
                            validation_data,
                            validation_steps,
                            max_queue_size=max_queue_size,
                            workers=workers,
                            use_multiprocessing=use_multiprocessing)
                    else:
                        # No need for try/except because
                        # data has already been validated.
                        val_outs = self.evaluate(
                            val_x, val_y,
                            batch_size=batch_size,
                            sample_weight=val_sample_weights,
                            verbose=0)
                    if not isinstance(val_outs, list):
                        val_outs = [val_outs]
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o

                if callback_model.stop_training:
                    break

            callbacks.on_epoch_end(epoch, epoch_logs)
            epoch += 1
            if callback_model.stop_training:
                break

    finally:
        if enqueuer is not None:
            enqueuer.stop()

    callbacks.on_train_end()
    return self.history

I'm currently training, and through the first epoch everything seemed to be working: the loss had a healthy downward trend. Then, when the second epoch started, the loss on the first step dropped very low, to about 0.9 (whereas it had ended the first epoch around 4). It jumped back up to 2.9 on the second step and then continued a downward trend, reaching 1.68 by step 91 (out of 143 steps per epoch), when it jumped back up to 2.9. I just have a feeling I'm missing something in my approach and would appreciate any assistance.

  • Does this answer your question? [Gradient Accumulation with Custom model.fit in TF.Keras?](https://stackoverflow.com/questions/66472201/gradient-accumulation-with-custom-model-fit-in-tf-keras) – Innat Oct 02 '22 at 02:33
  • Please trim your code to make it easier to find your problem. Follow these guidelines to create a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – Community Oct 02 '22 at 03:23
  • No, the gradient accumulation needs to be tailored for fit_generator – woods0813 Oct 02 '22 at 16:04
