I implemented the custom metric based on this post, with a few tweaks (e.g., I compute a running mean of the recall). Below is the code for the custom metric:
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric
class MacroAverageRecall(Metric):
    """Macro-averaged recall for multiclass classification, tracked during training.

    For every class ``c`` the metric accumulates two counts over all batches
    seen since the last reset:

    * ``true_positives[c]``: samples whose label and prediction are both ``c``.
    * ``actual_positives[c]``: samples whose label is ``c``.

    ``result()`` returns the unweighted mean over classes of
    ``true_positives / (actual_positives + epsilon)``. Accumulating counts,
    rather than averaging per-batch recalls, has two benefits:

    * It gives the recall over all data seen in the epoch.
    * It does not penalise a class merely for being absent from one batch.

    A class with no labelled samples at all since the last reset contributes 0.

    Args:
        num_classes: Number of classes. The last axis of ``y_true`` / ``y_pred``
            is presumably of this size; TODO confirm against the model.
        batch_size: Stored on the instance but not used by the computation.
            It is kept so existing call sites keep working.
        name: Metric name. Keras prefixes validation metrics with ``val_``, so
            the default is monitored as ``val_multiclass_recall``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self,
                 num_classes,
                 batch_size,
                 name='multiclass_recall',
                 **kwargs):
        super(MacroAverageRecall, self).__init__(name=name, **kwargs)
        self.batch_size = batch_size
        self.num_classes = num_classes
        # All cross-batch state lives in metric variables. A plain Python
        # attribute mutated in update_state() would only change when Keras
        # traces the method into a graph, not on every batch.
        self.true_positives = self.add_weight(
            name='true_positives', shape=(num_classes,), initializer='zeros')
        self.actual_positives = self.add_weight(
            name='actual_positives', shape=(num_classes,), initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch's per-class counts to the running totals.

        Args:
            y_true: Labels. The class is the argmax over the last axis, so
                labels are expected to be one-hot encoded.
            y_pred: Class scores/probabilities. The predicted class is the
                argmax over the last axis.
            sample_weight: Accepted for Keras API compatibility; ignored.
        """
        true_one_hot = K.one_hot(K.argmax(y_true, axis=-1), self.num_classes)
        pred_one_hot = K.one_hot(K.argmax(y_pred, axis=-1), self.num_classes)
        # Flatten any leading axes into a single sample axis, so that each
        # row is one sample and each column is one class.
        true_one_hot = K.reshape(true_one_hot, (-1, self.num_classes))
        pred_one_hot = K.reshape(pred_one_hot, (-1, self.num_classes))
        # A sample is a true positive for class c exactly when both one-hot
        # rows have a 1 in column c.
        self.true_positives.assign_add(
            K.sum(true_one_hot * pred_one_hot, axis=0))
        self.actual_positives.assign_add(K.sum(true_one_hot, axis=0))

    def result(self):
        """Return the macro-averaged recall over all batches since the last reset."""
        # epsilon avoids 0/0 for classes that have not appeared yet.
        per_class_recall = self.true_positives / (
            self.actual_positives + K.epsilon())
        return K.mean(per_class_recall)

    def reset_state(self):
        """Zero the accumulated counts (Keras calls this at the start of each epoch)."""
        self.true_positives.assign(K.zeros_like(self.true_positives))
        self.actual_positives.assign(K.zeros_like(self.actual_positives))

    def reset_states(self):
        """Alias of ``reset_state`` kept for older tf.keras releases that call this name."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments so the metric can be rebuilt when a saved model is loaded."""
        config = super(MacroAverageRecall, self).get_config()
        config.update({
            'num_classes': self.num_classes,
            'batch_size': self.batch_size,
        })
        return config
And here is the checkpoint callback used during model training:
callbacks.ModelCheckpoint(
    # The path has no epoch placeholder, so every save rewrites the same
    # checkpoint-<create_time>.h5 file.
    filepath=os.path.join(
        self._metadata['checkpoint_directory'],
        f'checkpoint-{self._metadata["create_time"]}.h5',
    ),
    # Keep only the best model when self._val is truthy. Presumably that
    # flag means validation data is available; confirm in the caller.
    save_best_only=bool(self._val),
    # Assumes the model was compiled with a metric named 'multiclass_recall'
    # (the MacroAverageRecall default); Keras logs its validation value
    # under the 'val_' prefix. Higher recall is better.
    monitor='val_multiclass_recall',
    mode='max',
    verbose=1,
)