I want to use TensorFlow's custom training loop for my model, but due to memory constraints I can only pass a small number of samples (mini-batches) through the model in one go. How can I train on these mini-batches and sensibly aggregate their gradients into a single update for the whole batch, on one machine (a single GPU/CPU)? Below is an example using code from here. Note that this example doesn't actually hit memory issues at this batch size, but it illustrates what I'm trying to do:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
#simple MNIST model
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Instantiate a metric to track training accuracy (used in train_step below).
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
If training on the full 64 sample batch size in one go could fit in memory we could simply use:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value
import time

epochs = 10
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))
However, how do I change train_step so that it runs, say, four mini-batches of 16 samples each (making up the full batch of 64), aggregates the gradients across those runs, and only then updates the model? That would let me handle my more memory-intensive data.
I tried writing a loop inside the with tf.GradientTape() as tape: block and stacking the per-mini-batch losses, but I don't think this is the correct approach.
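Roughly, my attempt looked like the sketch below (sub_batch_size and the slicing are just for illustration); since everything stays under one tape, I suspect it doesn't actually save any memory:

# Rough sketch of my attempt: loop over sub-batches inside one tape and stack the losses.
def train_step_attempt(x, y, sub_batch_size=16):
    losses = []
    with tf.GradientTape() as tape:
        for i in range(0, int(x.shape[0]), sub_batch_size):
            logits = model(x[i : i + sub_batch_size], training=True)
            losses.append(loss_fn(y[i : i + sub_batch_size], logits))
        # Mean over the stacked sub-batch losses.
        loss_value = tf.reduce_mean(tf.stack(losses))
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value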
I also thought about using tf.distribute.Strategy, but my understanding is that it is only intended for training across multiple machines or GPUs, so I don't see how I could use it here.
To summarise, what I want to do is agnostic to the dataset and model architecture. I guess I am looking for a gradient all-reduce style approach which, instead of splitting the mini-batches across different machines, just runs them iteratively on one device (see the sketch after the list below). It would need to:
- Compute the gradients on a single mini-batch.
- Compute the mean of the gradients over all mini-batches, in the style of an all-reduce collective.
- Update the model with the averaged gradients.
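I'm not sure this is right or efficient, but the shape of what I'm after is something like the sketch below (names such as accumulation_steps and compute_gradients are my own, and re-batching the dataset at 16 is just for illustration):

# Sketch of the accumulate-then-apply idea: 4 mini-batches of 16 = one logical batch of 64.
mini_batch_size = 16
accumulation_steps = 4

small_batch_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(mini_batch_size)
)

@tf.function
def compute_gradients(x, y):
    # Forward/backward pass on a single mini-batch only.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value, grads

# Buffers to accumulate gradients across mini-batches.
accumulated = [tf.zeros_like(w) for w in model.trainable_weights]
step_in_group = 0

for x_mini, y_mini in small_batch_dataset:
    loss_value, grads = compute_gradients(x_mini, y_mini)
    accumulated = [acc + g for acc, g in zip(accumulated, grads)]
    step_in_group += 1

    if step_in_group == accumulation_steps:
        # Average the gradients (a sequential "all-reduce" on one device)
        # and apply a single optimizer step for the full logical batch.
        mean_grads = [acc / accumulation_steps for acc in accumulated]
        optimizer.apply_gradients(zip(mean_grads, model.trainable_weights))
        accumulated = [tf.zeros_like(w) for w in model.trainable_weights]
        step_in_group = 0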
I assume that applying the mean of the gradients in a single step would be far less memory intensive than applying each mini-batch's gradients separately, as discussed here.