
I'm trying to write a custom loss function for a Keras neural network model. The custom loss combines a standard loss (e.g., MSE) with a penalty function. The penalty function is intended to verify that the model's predictions conform to a pre-defined form/shape. In the literature this is called a Physics-Informed Neural Network (PINN). See examples of this here and here.

The issue is that I need both the model itself and the input data to compute the desired custom penalty; the penalty is NOT based on the difference between the actual and predicted values. I found this SO question/answer that shows how to get the input data into the loss function. However, I don't see any way to access the model itself while calculating a custom loss. Is this possible? If so, how?

I would like to be able to do something like this, but obviously this won't work:

from tensorflow.keras import backend as K

def custom_loss(data, y_pred, model):
    # first column of `data` holds y_true; the remaining columns are the inputs
    y_true = data[:, 0]
    input_data = data[:, 1:]
    mse = K.mean(K.square(y_pred - y_true), axis=-1)
    # the penalty needs both the raw inputs and the model itself
    penalty = calculate_custom_penalty(input_data, model)
    return mse + penalty
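For reference, here is roughly how the input data reaches the loss with the stacking trick from that SO answer; X_train and y_train are placeholders for my actual arrays, and this still leaves the model itself unreachable from inside the loss:

import numpy as np

# Stack y_true and the input features into one array so that both arrive
# in the loss function through its first argument (split apart as above)
data_train = np.hstack([y_train.reshape(-1, 1), X_train])

model.fit(X_train, data_train, epochs=10)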

Per the Keras documentation for model.compile() found on GitHub, a "loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values, and y_pred are the model's predictions."

Is there a hack or trick to get around this requirement and give the loss function access to the model?
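The closest workaround I can think of is hiding the extra argument behind a closure, so the callable Keras sees still keeps the required two-argument signature. A sketch (the model must already exist when compile() is called):

def make_custom_loss(model):
    # Return a two-argument callable that satisfies Keras' required
    # signature while capturing `model` from the enclosing scope
    def loss_fn(data, y_pred):
        return custom_loss(data, y_pred, model)
    return loss_fn

model.compile(optimizer="adam", loss=make_custom_loss(model))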

Update

I thought I had found a way to make this possible using a custom callback class. With the callback below, I was able to get a custom_loss calculated at the end of each batch:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.callbacks import Callback


def custom_loss(y_true, y_pred, model):
    # Access the model here using the 'model' variable
    # Perform your calculations and return the loss value
    return tf.reduce_mean(tf.square(y_true - y_pred)) + 1.0


# Create a custom callback that passes the model to the loss function
class CustomModelCallback(Callback):
    def __init__(self, **kwargs):
        super().__init__()
        self.train_data = kwargs.get("train_data", None)
        self.validation_data = kwargs.get("validation_data", None)

    def on_train_begin(self, logs=None):
        self.model.custom_loss = custom_loss_wrapper(self.model)

    def on_train_batch_end(self, batch, logs=None):
        # Update the loss here using self.model.custom_loss
        # Get the true labels and predictions for the current batch
        y_true = self.train_data[1]
        y_pred = self.model.predict(self.train_data[0])

        # Calculate the custom loss for the current batch
        self.model.custom_loss = custom_loss_wrapper(self.model)
        custom_loss_value = self.model.custom_loss(y_true, y_pred)

        # Update the logs dictionary with the custom loss value
        logs['custom_loss'] = custom_loss_value.numpy()


def custom_loss_wrapper(model: Model):
    def custom_loss_with_model(y_true, y_pred):
        return custom_loss(y_true, y_pred, model)
    return custom_loss_with_model
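For completeness, this is roughly how I wire the callback into training; X_train and y_train are placeholders for my actual arrays:

callback = CustomModelCallback(train_data=(X_train, y_train))
model.compile(optimizer="adam", loss="mse", metrics=["mse"])
model.fit(X_train, y_train, epochs=5, callbacks=[callback])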

When I train with this code, I get output like this:

288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
5/5 [==============================] - 53s 4s/step - loss: 0.2978 - mse: 0.2978 - custom_loss: 1.2051

This basically just seems to record the custom loss as an extra metric in the logs. Also, oddly, the MSE and custom_loss don't match even after accounting for the + 1.0 in custom_loss (0.2978 + 1.0 = 1.2978, not the 1.2051 reported).

When I tried overwriting the loss in the logs with the custom loss value, like this: logs['loss'] = custom_loss_value.numpy(), I got the output below, which seems to completely ignore my custom_loss (the extra 1.0 is nowhere to be seen in the loss):

288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
5/5 [==============================] - 13s 2s/step - loss: 0.2985 - mse: 0.2985

I need the model to use this custom loss DURING backpropagation. It does not appear to be doing that; instead, it's just reporting an additional metric. How can I get this custom_loss to actually be used during backpropagation?
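Based on the Keras guide on customizing what happens in fit(), I suspect the answer involves overriding train_step() so the penalty is computed inside the GradientTape and therefore contributes to the gradients. This is an untested sketch of what I have in mind, with calculate_custom_penalty standing in for my penalty function from above:

import tensorflow as tf
from tensorflow import keras

class PINNModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            mse = tf.reduce_mean(tf.square(y - y_pred))
            # computed inside the tape, so the penalty's gradient flows
            # into the weight update along with the MSE gradient
            penalty = calculate_custom_penalty(x, self)
            loss = mse + penalty
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss, "mse": mse}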

