
What is the correct way to add a custom training loop to Keras in tensorflow 2, but with V1 compatibility? To explain why this is needed: I am familiar with overloading the train_step() method in modern Keras models. However, I am working on a project I started prior to tensorflow 2, which doesn't support that method. I was able to upgrade and get my code working again with the new version, but I ran into serious performance and memory issues related to the following questions:

I tried all of the tips suggested in those questions and elsewhere, but I don't achieve the same performance as when I run my code in compatibility mode, which I enable with

tf.compat.v1.disable_eager_execution()

The difference is a factor of two in performance and a memory-leak-like symptom that causes me to run out of RAM (I am running on CPU). I really do need to use the compatibility mode. Unfortunately, when I use the compatibility mode in tensorflow 2, the model no longer calls train_step() in my tf.keras.Model object and it doesn't use my custom training.
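For reference, this switch has to happen at the start of the program, before any tensors, graphs, or models are created; flipping it later does not affect objects that already exist:

```python
import tensorflow as tf

# Must run before any tensors, graphs, or models are created;
# calling it later has no effect on already-built objects.
tf.compat.v1.disable_eager_execution()

print(tf.executing_eagerly())  # False: ops now build a graph instead of running
```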

This leads me to ask: how can I implement custom training in a tensorflow 1 compatible Keras model? Specifically, the type of custom training that I want to do is add a soft constraint where I evaluate points in the problem domain and evaluate an additional loss term. The points should be randomly chosen and don't need to be in the training set. This looks like the following.

def train_step(self, data):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data

    # Make inputs for the soft constraint
    b = self.bounds  # Numpy array defining boundaries of the input variables
    x0 = (np.random.random((b.shape[1], self.n)) * (b[1] - b[0])[:, None] + b[0][:, None]).T

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Calculate the soft-constraint loss; note `lambda` is a reserved
        # word in Python, so the weight is stored as `self.lambda_`
        loss += self.lambda_ * constraint(self(x0, training=True))

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)

    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}

I have already looked into loss layers and additional loss functions, but these don't seem to let me evaluate the model at arbitrary extra points.
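For reference, the closest I got with the loss-layer route looks roughly like the sketch below (the ConstrainedModel class, its single Dense body layer, and the lambda_/constraint names are illustrative stand-ins, not an existing API; bounds and n play the same roles as in the train_step() version above). The extra term can be attached with add_loss() inside call(), but only by evaluating a fixed sub-layer at the random points; calling the full model from inside its own call() is not straightforward.

```python
import numpy as np
import tensorflow as tf

class ConstrainedModel(tf.keras.Model):
    """Illustrative sketch: attach the soft-constraint term via add_loss().

    `bounds`, `n`, `lambda_`, and `constraint` play the same roles as in
    the train_step() version; the single Dense layer stands in for the
    real network body.
    """
    def __init__(self, bounds, n, lambda_, constraint, **kwargs):
        super().__init__(**kwargs)
        self.lo = tf.constant(bounds[0], dtype=tf.float32)
        self.hi = tf.constant(bounds[1], dtype=tf.float32)
        self.d = bounds.shape[1]   # number of input variables
        self.n = n                 # number of random constraint points
        self.lambda_ = lambda_
        self.constraint = constraint
        self.body = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        out = self.body(inputs)
        if training:
            # Fresh random points in the box [lo, hi] on every training call
            x0 = tf.random.uniform((self.n, self.d)) * (self.hi - self.lo) + self.lo
            self.add_loss(self.lambda_ * self.constraint(self.body(x0)))
        return out
```

Losses registered with add_loss() should be summed into the compiled loss automatically during fit(), but only the sub-layer, not the whole model, is evaluated at the extra points.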

1 Answer


I guess your memory problem is not directly related to the back-compatibility with tensorflow 1, but to a known memory leak in tensorflow 2: see for instance link1 and link2.

The workaround is, at the end of each training session during hyper-parameter search, to clear the tensorflow session and then recompile the model.

import gc
from tensorflow.keras import backend as K
...
K.clear_session()
gc.collect()
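A minimal sketch of that pattern (the run_trial() helper, the layer widths, and the toy data are made up for illustration): build and compile a fresh model for each trial, then tear the session down before the next one so graphs don't accumulate.

```python
import gc
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def run_trial(units, x, y):
    # Fresh model for this hyper-parameter setting
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(units, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    loss = model.fit(x, y, epochs=1, verbose=0).history["loss"][-1]

    # Drop the graph and the Python references before the next trial
    del model
    K.clear_session()
    gc.collect()
    return loss

x = np.random.random((64, 4)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
results = {units: run_trial(units, x, y) for units in (8, 16)}
```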
Luca