Note:
After experimenting, I noticed that this problem only occurs when training on the GPU. I created a GitHub issue (#50454). At this point I am not sure what exactly is happening.
I am working on an implementation of gradient accumulation. However, none of my approaches seem to work. Below I describe two approaches which should work in theory but appear to conflict with TensorFlow.
The idea
I want to patch an arbitrary Optimizer instance by replacing its apply_gradients() function with my own implementation which accumulates gradients.
# Build model first
model.build()
# Patch the optimizer
optimizer = get_patched_optimizer(optimizer, n, model.trainable_variables)
# Compile the model with the patched optimizer
model.compile(optimizer=optimizer)
where
def get_patched_optimizer(optimizer, n, trainable_variables):
    """Patch optimizer for gradient accumulation.

    :param optimizer:
        The optimizer to patch.
    :param n:
        The number of accumulation steps before applying gradients.
    :param trainable_variables:
        Trainable parameters of the model.
    :return:
        A patched optimizer for gradient accumulation.
    """
    accumulator = _GradientAccumulationPatch(
        n=n,
        orig_apply_gradients=optimizer.apply_gradients,
        trainable_variables=trainable_variables
    )
    # Replace the original function
    optimizer.apply_gradients = accumulator.apply_gradients
    return optimizer
The happy (but not working) path
The simplest way would be to just accumulate gradients and apply them conditionally, e.g. whenever current_step % n == 0.
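For illustration, this is the behaviour I am after in a minimal eager-mode sketch (a toy loss and an SGD optimizer are assumed here, not my actual model):

import tensorflow as tf

n = 4  # number of accumulation steps
v = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
accu = tf.Variable(0.0)  # gradient accumulator for v

for step in range(1, 2 * n + 1):
    with tf.GradientTape() as tape:
        loss = v * v
    (grad,) = tape.gradient(loss, [v])
    accu.assign_add(grad)        # always accumulate
    if step % n == 0:            # apply only on every n-th step
        opt.apply_gradients([(accu, v)])
        accu.assign(0.0)         # reset the buffer

Inside a compiled train step, however, the Python if has to become a tf.cond(), and that is where the trouble starts.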
However, the problem here is that it looks like I cannot use tf.cond() in this context, in contrast to how it is done in "Gradient Accumulation with Custom model.fit in TF.Keras?".
Using tf.cond() results in the following RuntimeError:

RuntimeError: merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn
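For reference, the wrapping that the error message suggests would look roughly like this (a sketch; step_fn is a stand-in for the per-replica training step and not part of my code):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn(x):
    # Stand-in for the real per-replica training logic
    return x * 2.0

@tf.function  # wrap the entire strategy.run(), as the error suggests
def distributed_step(x):
    return strategy.run(step_fn, args=(x,))

In my case, though, the synchronization point sits inside optimizer.apply_gradients() itself, so I cannot simply move the control flow out of the step function with this patching approach.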
Here is the implementation of _GradientAccumulationPatch using tf.cond():
class _GradientAccumulationPatch:

    def __init__(
        self,
        n: int,
        orig_apply_gradients,
        trainable_variables
    ):
        self.n = tf.constant(n, dtype=tf.int64)
        policy = tf.keras.mixed_precision.global_policy()
        self.variable_dtype = policy.variable_dtype
        # One accumulator variable per trainable variable
        self.accu_gradients = [
            tf.Variable(tf.zeros(g.shape, dtype=g.dtype))
            for g in trainable_variables
        ]
        self._current_step = tf.Variable(0, dtype=tf.int64)
        self._orig_apply_gradients = orig_apply_gradients

    def apply_gradients(self, grads_and_vars, *args, **kwargs):
        # Materialize first, since grads_and_vars may be a zip iterator
        grads_and_vars = list(grads_and_vars)
        gradients = [grad for (grad, _) in grads_and_vars]
        trainable_variables = [var for (_, var) in grads_and_vars]
        # Always accumulate gradients
        for i, grad in enumerate(gradients):
            self.accu_gradients[i].assign_add(grad)
        tf.cond(
            self._can_apply_on_next_step(),
            true_fn=lambda: self.apply_accu_gradients(trainable_variables, *args, **kwargs),
            false_fn=lambda: None
        )

    def apply_accu_gradients(self, trainable_variables, *args, **kwargs):
        # Call the original apply_gradients() function
        self._orig_apply_gradients(zip(self.accu_gradients, trainable_variables), *args, **kwargs)
        # Reset all accumulated gradients to zero
        for i in range(len(self.accu_gradients)):
            self.accu_gradients[i].assign(tf.zeros_like(trainable_variables[i]))

    def _can_apply_on_next_step(self):
        """
        :return: True if gradients should be applied; False otherwise.
        """
        # Increment (always do this first)
        self._current_step.assign_add(1)
        count_mod_steps = tf.math.mod(self._current_step, self.n)
        return tf.equal(count_mod_steps, 0)
The more complicated path (also not working)
It is possible to remove the tf.cond() by using the signal apply, given by _can_apply_on_next_step(), as a multiplication factor and applying zero-gradients whenever we are in the accumulation phase.
The idea would be to always accumulate gradients and always apply them, with one particular change:
final_gradients = [grad * apply for grad in gradients]
self._orig_apply_gradients(zip(final_gradients, trainable_variables))
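To make the signal mechanics concrete, here is a toy example of the two factors (float32 values assumed):

import tensorflow as tf

can_apply = tf.constant(True)  # pretend we are on the n-th step
apply = tf.cast(can_apply, tf.float32)                  # 1.0
keep = tf.cast(tf.logical_not(can_apply), tf.float32)   # 0.0

accu = tf.Variable([2.0, 4.0])  # pretend: accumulated gradients
print((accu * apply).numpy())   # [2. 4.] -> passed on to apply_gradients()
accu.assign(accu * keep)
print(accu.numpy())             # [0. 0.] -> buffer is reset

On a pure accumulation step the roles flip: zero-gradients are applied and the buffer is kept.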
This is how we'd change the apply_gradients() method:
def apply_gradients(self, grads_and_vars, *args, **kwargs):
    can_apply = self._can_apply_on_next_step()
    # 1.0 whenever we want to apply gradients; 0.0 otherwise
    apply = tf.cast(can_apply, dtype=self.variable_dtype)
    # Will be 0.0 if apply is 1.0 and vice versa
    keep = tf.cast(tf.logical_not(can_apply), dtype=self.variable_dtype)
    grads_and_vars = list(grads_and_vars)
    gradients = [grad for (grad, _) in grads_and_vars]
    trainable_variables = [var for (_, var) in grads_and_vars]
    # Accumulate gradients
    for i, grad in enumerate(gradients):
        self.accu_gradients[i].assign_add(grad)
    # Multiply each gradient with our apply-signal
    final_gradients = [grad * apply for grad in self.accu_gradients]
    self._orig_apply_gradients(zip(final_gradients, trainable_variables), *args, **kwargs)
    # This will reset our buffer whenever "keep" is 0.0
    for g in self.accu_gradients:
        g.assign(g * keep)
But the problem is that self.accu_gradients[i].assign_add(grad) does not seem to have any effect. And yes, I have also tried

self.accu_gradients[i].assign(grad + self.accu_gradients[i])

Interestingly, the model starts to converge if I use assign(grad) instead, i.e. self.accu_gradients[i].assign(grad), as you can see:
[Plot of training loss] blue: just using assign() (no accumulation happening); red: using assign_add()
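For what it's worth, assign_add() itself behaves as expected when run eagerly (a minimal sanity check on CPU, no strategy involved):

import tensorflow as tf

v = tf.Variable(tf.zeros(2))
v.assign_add(tf.constant([1.0, 1.0]))
v.assign_add(tf.constant([1.0, 1.0]))
print(v.numpy())  # [2. 2.]

which fits my observation from the note above that the problem only occurs when training on the GPU.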
The train_step()
This patch should work model-independently. I do have a custom train_step() for my model, but its implementation is pretty straightforward. Here I just compute the gradients and then call the apply_gradients() method of the optimizer:
def train_step(self, data):
    (inputs, (input_lengths, label_lengths), mask), y_true = data
    loss, gradients = self.rnnt_gradient(
        inputs=inputs,
        y_true=y_true,
        input_lengths=input_lengths,
        label_lengths=label_lengths,
        mask=mask
    )
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    return {'loss': loss}

def test_step(self, data):
    (inputs, (input_lengths, label_lengths), mask), y_true = data
    val_loss = self.rnnt_loss_wrapper(
        inputs=inputs,
        y_true=y_true,
        input_lengths=input_lengths,
        label_lengths=label_lengths,
        mask=mask
    )
    return dict(loss=val_loss)
def rnnt_gradient(
    self,
    inputs: tuple,
    y_true: tf.Tensor,
    input_lengths: tf.Tensor,
    label_lengths: tf.Tensor,
    mask=None
):
    with tf.GradientTape() as tape:
        model_loss = self.rnnt_loss_wrapper(
            inputs,
            y_true=y_true,
            input_lengths=input_lengths,
            label_lengths=label_lengths,
            mask=mask
        )
        is_mixed_precision = isinstance(self.optimizer, mixed_precision.LossScaleOptimizer)
        # We always want to return the unmodified model_loss for TensorBoard
        if is_mixed_precision:
            loss = self.optimizer.get_scaled_loss(model_loss)
        else:
            loss = model_loss
    gradients = tape.gradient(loss, self.trainable_variables)
    if is_mixed_precision:
        gradients = self.optimizer.get_unscaled_gradients(gradients)
    return model_loss, gradients
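For completeness, a model-independent repro would look roughly like this (a minimal sketch using get_patched_optimizer from above with a toy model; the hyperparameters are arbitrary):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))

optimizer = get_patched_optimizer(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    n=4,
    trainable_variables=model.trainable_variables
)
model.compile(optimizer=optimizer, loss='mse')

x = tf.random.normal((64, 4))
y = tf.random.normal((64, 1))
model.fit(x, y, batch_size=8, epochs=1)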