
Note:

After experimenting I noticed that this problem only occurs when I am training on the GPU. I created a GitHub issue (#50454). At this point I am not sure what exactly is happening.

I am working on an implementation of gradient accumulation. However, none of my approaches seem to work. Below I describe two approaches which should work in theory but appear to conflict with TensorFlow.

The idea

I want to patch an arbitrary Optimizer instance by replacing its apply_gradients() function with my own implementation, which accumulates gradients.

# Build model first
model.build()

# Patch the optimizer
optimizer = get_patched_optimizer(optimizer, n, model.trainable_variables)

# Compile the model with the patched optimizer
model.compile(optimizer=optimizer)

where

def get_patched_optimizer(optimizer, n, trainable_variables):
    """Patch optimizer for gradient accumulation.

    :param optimizer:
        The optimizer to patch.
    :param n:
        The number of accumulation steps before applying gradients.
    :param trainable_variables:
        Trainable parameters of the model
    :return:
        A patched optimizer for gradient accumulation.
    """
    accumulator = _GradientAccumulationPatch(
        n=n,
        orig_apply_gradients=optimizer.apply_gradients,
        trainable_variables=trainable_variables
    )

    # Replace the original function
    optimizer.apply_gradients = accumulator.apply_gradients

    return optimizer

The happy (but not working) path

The simplest way would be to just accumulate gradients and only apply them conditionally, e.g. whenever current_step % n == 0.
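Conceptually, the per-step logic looks like this toy, eager-style sketch (a made-up one-parameter "model", purely for illustration; inside a compiled train_step the if would have to become a tf.cond):

import tensorflow as tf

# Toy eager sketch: accumulate every step, apply only every n-th step.
n = 4
var = tf.Variable(1.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

step = tf.Variable(0, dtype=tf.int64)
accu = tf.Variable(0.0)  # one accumulator per trainable variable

for _ in range(8):
    with tf.GradientTape() as tape:
        loss = tf.square(var)  # dummy loss
    grad = tape.gradient(loss, var)

    step.assign_add(1)
    accu.assign_add(grad)  # always accumulate

    if step % n == 0:
        # Apply the accumulated gradient and reset the buffer.
        optimizer.apply_gradients([(accu.read_value(), var)])
        accu.assign(0.0)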

However, the problem here is that I am apparently not able to use tf.cond() in this context, in contrast to how it is done in Gradient Accumulation with Custom model.fit in TF.Keras?

Using tf.cond() results in the following RuntimeError:

RuntimeError: merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn

Here is the implementation of _GradientAccumulationPatch using tf.cond():

class _GradientAccumulationPatch:

    def __init__(
        self,
        n: int,
        orig_apply_gradients,
        trainable_variables
    ):
        self.n = tf.constant(n, dtype=tf.int64)
        policy = tf.keras.mixed_precision.global_policy()
        self.variable_dtype = policy.variable_dtype
        self.accu_gradients = [
            tf.Variable(
                tf.zeros(g.shape, dtype=g.dtype),
            ) for g in trainable_variables
        ]

        self._current_step = tf.Variable(0, dtype=tf.int64)
        self._orig_apply_gradients = orig_apply_gradients

    def apply_gradients(self, grads_and_vars, *args, **kwargs):

        trainable_variables = [var for (_, var) in grads_and_vars]
        gradients = [grad for (grad, _) in grads_and_vars]

        # Always accumulate gradients
        for i, grad in enumerate(gradients):
            self.accu_gradients[i].assign_add(grad)

        tf.cond(
            self._can_apply_on_next_step(),
            true_fn=lambda: self.apply_accu_gradients(trainable_variables, *args, **kwargs),
            false_fn=lambda: None
        )

    def apply_accu_gradients(self, trainable_variables, *args, **kwargs):

        # Call the original apply_gradients() function
        self._orig_apply_gradients(zip(self.accu_gradients, trainable_variables), *args, **kwargs)

        # Reset all accumulated gradients to zero
        for i in range(len(self.accu_gradients)):
            self.accu_gradients[i].assign(tf.zeros_like(trainable_variables[i]))

    def _can_apply_on_next_step(self):
        """
        :return: True if gradients should be applied; False otherwise.
        """
        # Increment (always do this first)
        self._current_step.assign_add(1)
        count_mod_steps = tf.math.mod(self._current_step, self.n)
        return tf.equal(count_mod_steps, 0)

The more complicated path (also not working)

It is possible to remove the tf.cond() by using the signal from _can_apply_on_next_step() as a multiplication factor, so that zero gradients are applied whenever we are still in the accumulation phase.

The idea would be to always accumulate gradients and always apply them with one particular change:

final_gradients = [grad * apply for grad in gradients]
self._orig_apply_gradients(zip(final_gradients, trainable_variables))

This is how we'd change the apply_gradients() method:

def apply_gradients(self, grads_and_vars, *args, **kwargs):

    can_apply = self._can_apply_on_next_step()
    # 1.0 whenever we want to apply gradients; 0.0 otherwise
    apply = tf.cast(can_apply, dtype=self.variable_dtype)
    # Will be 0.0 if apply is 1.0 and vice versa
    keep = tf.cast(tf.logical_not(can_apply), dtype=self.variable_dtype)

    grads_and_vars = list(grads_and_vars)
    gradients = [grad for (grad, _) in grads_and_vars]
    trainable_variables = [var for (_, var) in grads_and_vars]

    # Accumulate gradients
    for i, grad in enumerate(gradients):
        self.accu_gradients[i].assign_add(grad)

    # Multiply each gradient with our apply-signal
    final_gradients = [grad * apply for grad in self.accu_gradients]

    self._orig_apply_gradients(zip(final_gradients, trainable_variables), *args, **kwargs)

    # This will reset our buffer whenever "keep" is 0.0
    for g in self.accu_gradients:
        g.assign(g * keep)

But the problem is that self.accu_gradients[i].assign_add(grad) does not seem to have any effect. And yes, I have also tried

self.accu_gradients[i].assign(grad + self.accu_gradients[i])

Interestingly, the model starts to converge if I use assign(grad) instead of assign_add(grad), i.e. self.accu_gradients[i].assign(grad), as you can see:

blue: just using assign()   # <- no accumulation happening
red:  using assign_add()

[Image: training-loss curves for the two runs]
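For reference, assign_add on a tf.Variable does work inside a tf.function in general, as this minimal standalone sketch (not part of the patch above) shows, which is what makes the missing effect inside the patched apply_gradients() so puzzling:

import tensorflow as tf

# Standalone sanity check: a persistent buffer accumulated via assign_add
# inside a tf.function.
buffer = tf.Variable(tf.zeros([2]))

@tf.function
def accumulate(grad):
    buffer.assign_add(grad)
    return buffer.read_value()

print(accumulate(tf.ones([2])))  # [1. 1.]
print(accumulate(tf.ones([2])))  # [2. 2.]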

The train_step()

This patch should work independently of the model. I do have a custom train_step() for my model, but its implementation is pretty straightforward.

Here I am just computing the gradients and then calling the apply_gradients() method of the optimizer:

def train_step(self, data):

    (inputs, (input_lengths, label_lengths), mask), y_true = data

    loss, gradients = self.rnnt_gradient(
        inputs=inputs,
        y_true=y_true,
        input_lengths=input_lengths,
        label_lengths=label_lengths,
        mask=mask
    )

    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    return {'loss': loss}

def test_step(self, data):

    (inputs, (input_lengths, label_lengths), mask), y_true = data

    val_loss = self.rnnt_loss_wrapper(
        inputs=inputs,
        y_true=y_true,
        input_lengths=input_lengths,
        label_lengths=label_lengths,
        mask=mask
    )

    return dict(loss=val_loss)


def rnnt_gradient(
    self,
    inputs: tuple,
    y_true: tf.Tensor,
    input_lengths: tf.Tensor,
    label_lengths: tf.Tensor,
    mask=None
):
    with tf.GradientTape() as tape:
        model_loss = self.rnnt_loss_wrapper(
            inputs,
            y_true=y_true,
            input_lengths=input_lengths,
            label_lengths=label_lengths,
            mask=mask
        )

        is_mixed_precision = isinstance(self.optimizer, mixed_precision.LossScaleOptimizer)

        # We always want to return the unmodified model_loss for Tensorboard
        if is_mixed_precision:
            loss = self.optimizer.get_scaled_loss(model_loss)
        else:
            loss = model_loss

        gradients = tape.gradient(loss, self.trainable_variables)

        if is_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        return model_loss, gradients
  • Does this answer your question? [Gradient Accumulation with Custom model.fit in TF.Keras?](https://stackoverflow.com/questions/66472201/gradient-accumulation-with-custom-model-fit-in-tf-keras) – Innat Jun 25 '21 at 00:14
  • @M.Innat In essence, it's quite the same what they're doing but I want to implement GA in an optimizer-patch and not on the model-side. The problem I face when using their approach is that `tf.cond` does not work in this context. Idk why this works for them. – Stefan Falk Jun 25 '21 at 06:00
  • Have you checked [this](https://stackoverflow.com/a/55281501/9215780)? – Innat Jun 25 '21 at 06:43
  • @M.Innat Yes, I have seen that too. I have also updated my question to clarify what I have tried and what exactly is not working in my case. – Stefan Falk Jun 25 '21 at 07:04

1 Answer


It turns out that this was entirely my fault: whenever I trained with a mixed_float16 policy, I was patching the wrong instance.

What I had was something like:

if precision_policy.name.startswith('mixed'):
    logger.info(f'Using LossScaleOptimizer (policy: "{precision_policy.name}")')
    optimizer = keras.mixed_precision.LossScaleOptimizer(optimizer)

if grad_acc_n > 1:
    # --> This patched the LossScaleOptimizer which caused the problem:
    optimizer = grad_acc.get_patched_optimizer(optimizer=optimizer, n=grad_acc_n)

So I would require a check like:

if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
    # Warning: This does NOT work either (just an example)!
    optimizer.inner_optimizer.apply_gradients = accumulator.apply_gradients
    raise Exception('Don\'t do this!')
else:
    optimizer.apply_gradients = accumulator.apply_gradients

However, as stated in the code comment above, patching the inner_optimizer does not work either. I haven't figured out why yet, but at least I am now able to run a "normal" float32-policy training with my _GradientAccumulationPatch implementation.
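For completeness, this is roughly the float32 setup that does run (a sketch reusing get_patched_optimizer from the question; the concrete optimizer and grad_acc_n are placeholders, not my exact values):

import tensorflow as tf

# Plain float32 policy, so no LossScaleOptimizer wrapping is involved.
# The Adam settings and grad_acc_n are placeholders.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.build()  # build first so model.trainable_variables exists

if grad_acc_n > 1:
    optimizer = get_patched_optimizer(
        optimizer=optimizer,
        n=grad_acc_n,
        trainable_variables=model.trainable_variables
    )

model.compile(optimizer=optimizer)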
