After training the model, I use the predict method to infer the scores for my testing data. How can I use the model's compiled loss and metrics to calculate the loss of these predictions?
What I have tried
Based on the Customizing what happens in fit() guide, I tried calling the compiled_loss method:
y_pred = model.predict(x_test)
model.compiled_loss(y_test, y_pred, regularization_losses=model.losses)
But it raises the following error:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-3eb62dca0b87> in <module>()
      1 y_pred = model.predict(x_test)
----> 2 loss = model.compiled_loss(y_test, y_pred)

1 frames
/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py in match_dtype_and_rank(y_t, y_p, sw)
    673 def match_dtype_and_rank(y_t, y_p, sw):
    674   """Match dtype and rank of predictions."""
--> 675   if y_t.shape.rank == 1 and y_p.shape.rank == 2:
    676     y_t = tf.expand_dims(y_t, axis=-1)
    677     if sw is not None:

AttributeError: 'tuple' object has no attribute 'rank'
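I suspect the failure is just that predict returns NumPy arrays, whose .shape is a plain tuple rather than a TensorShape with a rank attribute. Converting the inputs to tensors first might sidestep that check, though I have not verified that the rest of the call then works:

import tensorflow as tf

# Untested guess: compiled_loss seems to expect tf.Tensors, whose .shape
# exposes .rank, while NumPy arrays expose a plain tuple.
y_pred = model.predict(x_test)
loss = model.compiled_loss(
    tf.convert_to_tensor(y_test),
    tf.convert_to_tensor(y_pred),
    regularization_losses=model.losses,
)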
How to reproduce
I used the Simple MNIST convnet example followed by
y_pred = model.predict(x_test)
model.compiled_loss(y_test, y_pred, regularization_losses=model.losses)
to reproduce the error.
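For reference, the model in that example is compiled roughly like this (quoted from memory, so treat it as approximate):

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])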
About my problem
I am validating my model with a custom metric. However, some Keras users recommend that global metrics should not be averaged per batch but instead computed from the predicted scores over the whole validation data in a Callback (roughly as sketched after the link below).
See:
How to calculate F1 Macro in Keras?
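For concreteness, the pattern I mean looks roughly like this; the class name and details are mine, and it leans on scikit-learn for the metric:

import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score

class GlobalF1Callback(tf.keras.callbacks.Callback):
    # Computes macro F1 over the entire validation set once per epoch,
    # instead of averaging a per-batch approximation.
    def __init__(self, x_val, y_val):
        super().__init__()
        self.x_val = x_val
        self.y_val = y_val

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.x_val, verbose=0)
        f1 = f1_score(
            np.argmax(self.y_val, axis=1),
            np.argmax(y_pred, axis=1),
            average="macro",
        )
        print(f"epoch {epoch + 1}: val_macro_f1 = {f1:.4f}")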
A bad solution is to calculate the loss and metrics with the evaluate method, and my custom metric from the output of predict. The problem with this is that it runs inference twice.
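Concretely, that would look something like this, with the test set passing through the network twice:

# First pass: compiled loss and metrics.
results = model.evaluate(x_test, y_test, verbose=0)
# Second pass over the same data: scores for the custom metric.
y_pred = model.predict(x_test, verbose=0)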
A less bad solution is to implement my loss function separately so that it can work directly on the predicted scores.
See:
Calculate loss in Keras without running the model
The issue with this is that it gives me less flexibility in choosing loss functions, because I would have to reimplement every loss function separately in the Callback.
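In other words, for every compiled configuration I would need a hand-wired duplicate, something like this (assuming categorical cross-entropy, as in the MNIST example):

import tensorflow as tf

# Duplicates information that model.compile() already holds.
loss_fn = tf.keras.losses.CategoricalCrossentropy()
y_pred = model.predict(x_test)
loss_value = loss_fn(y_test, y_pred).numpy()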
But I really wonder: aren't the compiled loss and metrics already accessible somewhere?