What is the correct way to add a custom training loop to Keras in tensorflow 2, but with V1 compatibility? To explain why this is needed: I am familiar with overriding the train_step() method in modern Keras models. However, I am working on a project I started prior to tensorflow 2, and the compatibility mode it requires doesn't support that method. I was able to upgrade and get my code working again with the new version, but I ran into serious performance and memory issues related to the following questions:
- Keras: Out of memory when doing hyper parameter grid search
- Keras occupies an indefinitely increasing amount of memory for each epoch
I tried all of the tips suggested in these questions and elsewhere, but I still don't get the same performance as when I run my code in compatibility mode, which I enable with:

```python
tf.compat.v1.disable_eager_execution()
```
The difference is a factor of two in performance, plus a memory-leak-like symptom that causes me to run out of RAM (I am running on CPU), so I really do need to use the compatibility mode. Unfortunately, with compatibility mode enabled in tensorflow 2, fit() no longer calls the train_step() method of my tf.keras.Model subclass, and it doesn't use my custom training.
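For reference, here is a minimal sketch of what I mean (the one-layer model and random data are placeholders of my own): once eager execution is disabled, the print inside train_step() never fires during fit().

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

class ProbeModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        return self.dense(inputs)

    def train_step(self, data):
        print("train_step called")  # never printed in compatibility mode
        return super().train_step(data)

model = ProbeModel()
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((32, 4)), np.random.random((32, 1)), epochs=1)
```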
This leads me to ask: how can I implement custom training in a tensorflow 1 compatible Keras model? Specifically, the type of custom training I want to do is to add a soft constraint: at each step I evaluate the model at extra points in the problem domain and add a penalty term to the loss. The points should be randomly chosen and don't need to be in the training set. In eager tensorflow 2 this looks like the following (constraint is my penalty function and self.lambda_ its weight).
```python
import numpy as np
import tensorflow as tf

def train_step(self, data):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data
    # Make inputs for the soft constraint: draw self.n points uniformly
    # from the box defined by self.bounds (a 2 x d numpy array whose rows
    # are the lower and upper bounds of the d input variables).
    b = self.bounds
    x0 = (np.random.random((b.shape[1], self.n)) * (b[1] - b[0])[:, None]
          + b[0][:, None]).T
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Add the constraint loss, evaluated at the random points
        # (lambda is a reserved word in Python, hence lambda_)
        loss += self.lambda_ * constraint(self(x0, training=True))
    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
```
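For context, the attributes used above are set in the constructor, roughly like this (bounds, n, and lambda_ are my own names, and the single Dense layer is just a placeholder body):

```python
class ConstrainedModel(tf.keras.Model):
    def __init__(self, bounds, n_points, weight, **kwargs):
        super().__init__(**kwargs)
        self.bounds = np.asarray(bounds)  # 2 x d: lower and upper bounds
        self.n = n_points                 # random constraint points per step
        self.lambda_ = weight             # weight of the constraint penalty
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        return self.dense(inputs)
```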
I have already looked into loss layers and additional loss functions, but these don't seem to let me evaluate the model at arbitrary extra points.
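For example, the closest I got with add_loss() is something like this sketch (same placeholder model as above, constraint is the same penalty function); the penalty can only be built from tensors that already flow through call(), i.e. the training inputs, not freshly sampled points x0:

```python
class AddLossModel(tf.keras.Model):
    def __init__(self, weight, **kwargs):
        super().__init__(**kwargs)
        self.lambda_ = weight
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        outputs = self.dense(inputs)
        # This penalizes the model only at the batch inputs; there is no
        # obvious way here to evaluate it at arbitrary extra points.
        self.add_loss(self.lambda_ * constraint(outputs))
        return outputs
```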