I'm trying to write a custom loss function for a Keras neural network model. The custom loss will combine a standard loss (e.g., MSE) with a penalty term. The penalty is intended to verify that the model's predictions conform to a predefined form/shape; in the literature this is called a physics-informed neural network (PINN). See examples of this here and here.
The issue is that I need both the model itself and the input data to compute the desired custom penalty; the penalty is NOT based on the difference between the actual and predicted values. I found this SO question/answer that shows how to get the input data into the loss function. However, I don't see any way to access the model itself while calculating a custom loss. Is this possible? If so, how?
I would like to be able to do something like this, but obviously this won't work:
def custom_loss(data, y_pred, model):
    y_true = data[:, 0]        # first column carries the labels
    input_data = data[:, 1:]   # remaining columns carry the raw inputs
    mse = K.mean(K.square(y_pred - y_true), axis=-1)
    penalty = calculate_custom_penalty(input_data, model)
    return mse + penalty
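(For context, the trick from the linked SO answer smuggles the inputs into the loss through the target tensor: you stack the labels and the features on the training side, so data above arrives as one array. A minimal sketch, assuming a 1-D target and numpy arrays X_train / y_train:)

import numpy as np

# Stack labels and features so the array Keras passes to the loss as
# "y_true" carries both (the trick from the linked SO answer).
combined_target = np.hstack([y_train.reshape(-1, 1), X_train])
model.fit(X_train, combined_target, epochs=5)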
Per the Keras documentation for model.compile() found in GitHub, it looks like a "loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values, and y_pred are the model's predictions."
Is there a hack or trick to get around this requirement and include the model in the loss function?
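The only workaround I can think of so far is to close over the model so the compiled loss still has the required two-argument signature. A minimal sketch, combining this with the stacked-target trick above; calculate_custom_penalty is still the hypothetical penalty from before, and it would have to be expressible in differentiable TensorFlow ops for this to train at all:

import tensorflow as tf

def make_model_loss(model):
    # The returned closure has the (y_true, y_pred) signature that
    # Keras requires, but it can still see the model object.
    def loss(data, y_pred):
        y_true = data[:, 0:1]      # first column: labels
        input_data = data[:, 1:]   # remaining columns: inputs
        mse = tf.reduce_mean(tf.square(y_pred - y_true))
        penalty = calculate_custom_penalty(input_data, model)  # hypothetical
        return mse + penalty
    return loss

model = build_model()  # hypothetical model factory
model.compile(optimizer='adam', loss=make_model_loss(model))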
Update
I thought I found a way to make this possible using a custom callback class. With the callback below, I was able to get a custom_loss calculated at the end of each batch:
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model


def custom_loss(y_true, y_pred, model):
    # Access the model here through the 'model' argument, perform the
    # penalty calculation, and return the loss value. The + 1.0 is a
    # stand-in for the real penalty.
    return tf.reduce_mean(tf.square(y_true - y_pred)) + 1.0


def custom_loss_wrapper(model: Model):
    # Close over the model so the returned function has the standard
    # (y_true, y_pred) loss signature.
    def custom_loss_with_model(y_true, y_pred):
        return custom_loss(y_true, y_pred, model)
    return custom_loss_with_model


# A custom callback that passes the model to the loss function
class CustomModelCallback(Callback):
    def __init__(self, **kwargs):
        super().__init__()
        self.train_data = kwargs.get("train_data", None)
        self.validation_data = kwargs.get("validation_data", None)

    def on_train_begin(self, logs=None):
        self.model.custom_loss = custom_loss_wrapper(self.model)

    def on_train_batch_end(self, batch, logs=None):
        # Get the true labels and predictions for the current batch
        y_true = self.train_data[1]
        y_pred = self.model.predict(self.train_data[0])
        # Calculate the custom loss and report it via the logs dictionary
        custom_loss_value = self.model.custom_loss(y_true, y_pred)
        logs['custom_loss'] = custom_loss_value.numpy()
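For reference, this is how I attach the callback during training (a sketch; X_train and y_train stand in for my actual training arrays, and the compile settings match the mse loss/metric visible in the output below):

callback = CustomModelCallback(train_data=(X_train, y_train))
model.compile(optimizer='adam', loss='mse', metrics=['mse'])
model.fit(X_train, y_train, epochs=5, callbacks=[callback])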
When I train with this code, I get output like this:
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
5/5 [==============================] - 53s 4s/step - loss: 0.2978 - mse: 0.2978 - custom_loss: 1.2051
This basically seems to just report the custom loss as an extra metric in the logs. Also, oddly, the MSE and the custom_loss don't match (after accounting for the + 1.0 in the custom_loss).
When I tried overwriting the loss in the logs with the custom_loss value, like this: logs['loss'] = custom_loss_value.numpy(), I got the output below, which seems to completely ignore my custom_loss (the extra 1.0 never shows up in the reported loss).
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
288/288 [==============================] - 1s 3ms/step
5/5 [==============================] - 13s 2s/step - loss: 0.2985 - mse: 0.2985
I need the model to use this custom loss DURING back propagation. It does not appear to be doing that; instead, it's just reporting an additional metric. How can I get this custom_loss to be used during back propagation?