I am building a CNN classifier for a multi-class classification task (num_classes=7). Due to imbalance and the subject area, my target metric for this task is macro-average recall across the classes.

As the model trains, I would like to checkpoint it at the end of each epoch, saving the model only if the validation macro-average recall is higher than the best value seen in any previous epoch. I believe that this will work in two stages:

  1. Creating a custom metric for calculating the average recall across the classes for a multi-class scenario on the validation data at the end of each epoch
  2. Creating a ModelCheckpoint callback that tracks the custom metric and saves the model if it has exceeded the previous max.

Would anyone have examples of this or something similar? I am mainly interested in the implementation of the custom metric for macro-average multi-class recall, as I believe the callback is straightforward once this metric is passed to model.compile().
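For reference, the quantity I am after can be sketched in plain NumPy as below. This is only an illustration of the definition, not the Keras metric itself; the function name `macro_recall` and the choice to skip classes with no true samples are my own assumptions.

```python
import numpy as np

def macro_recall(y_true, y_pred, num_classes):
    """Unweighted mean of per-class recall over integer class labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = []
    for c in range(num_classes):
        actual = y_true == c
        if not actual.any():
            # Assumption: a class with no true samples is skipped, not counted as 0.
            continue
        # Fraction of class-c samples that were predicted as class c
        recalls.append(np.mean(y_pred[actual] == c))
    return float(np.mean(recalls)) if recalls else 0.0

# Toy labels: the majority class dominates accuracy but not macro recall.
y_true = [0, 0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 0, 1, 1, 0, 2]
print(macro_recall(y_true, y_pred, num_classes=3))
```

Each class contributes equally to the average regardless of how many samples it has, which is why this metric suits an imbalanced problem.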

cian

1 Answer

I implemented the custom metric by adapting this post with a few tweaks, e.g. keeping a running mean of the per-batch macro recall over the epoch. Below is the code for the custom metric:

import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric

class MacroAverageRecall( Metric ):
    """Custom metric for calculating multiclass recall during training"""
    def __init__(self,
                 num_classes,
                 batch_size,
                 name='multiclass_recall',
                 **kwargs):
        super( MacroAverageRecall, self ).__init__( name=name, **kwargs )
        # batch_size is kept for API compatibility; the computation does not use it.
        self.batch_size = batch_size
        self.num_classes = num_classes
        # Both pieces of state are weights so they update correctly inside tf.function;
        # a plain Python counter would only be incremented when the graph is traced.
        self.num_batches = self.add_weight( name="num_batches", initializer="zeros" )
        self.average_recall = self.add_weight( name="recall", initializer="zeros" )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Assumes one-hot y_true and per-class scores in y_pred; sample_weight is ignored.
        recall = 0
        pred = K.argmax( y_pred, axis=-1 )
        true = K.argmax( y_true, axis=-1 )

        for i in range( self.num_classes ):
            # 1.0 where the prediction / label equals class i, else 0.0
            predicted_instances = K.cast( K.equal( pred, i ), 'float32' )
            true_instances = K.cast( K.equal( true, i ), 'float32' )
            # True positives: samples of class i that were predicted as class i
            true_positives = K.sum( true_instances * predicted_instances )
            # Divide by the number of samples whose label is class i
            all_true_positives = K.sum( true_instances ) + K.epsilon()
            recall += true_positives / all_true_positives

        # Running mean of the per-batch macro recall
        self.num_batches.assign_add( 1. )
        avg_recall = recall / self.num_classes
        recall_update = (avg_recall - self.average_recall) / self.num_batches
        self.average_recall.assign_add( recall_update )

    def result(self):
        return self.average_recall

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.num_batches.assign( 0. )
        self.average_recall.assign( 0. )
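One caveat worth checking on your own data: a mean of per-batch recalls is not in general the same number as the recall computed over the whole validation set, especially when some classes are missing from a batch. The NumPy sketch below illustrates the difference by accumulating per-class counts instead. `StreamingMacroRecall` is a hypothetical helper written for this illustration, not a Keras class, and it skips classes that have not been seen (the metric above instead counts them as recall 0).

```python
import numpy as np

class StreamingMacroRecall:
    """Accumulates per-class counts across batches (NumPy sketch, not a Keras Metric)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        self.true_positives = np.zeros(self.num_classes)
        self.actual_positives = np.zeros(self.num_classes)

    def update(self, y_true, y_pred):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        for c in range(self.num_classes):
            actual = y_true == c
            self.actual_positives[c] += actual.sum()
            self.true_positives[c] += np.sum(actual & (y_pred == c))

    def result(self):
        # Assumption: classes never seen so far are left out of the average.
        seen = self.actual_positives > 0
        if not seen.any():
            return 0.0
        return float(np.mean(self.true_positives[seen] / self.actual_positives[seen]))

# Two toy batches of (y_true, y_pred); class 0 is absent from the second batch.
batches = [
    ([0, 0, 1], [0, 0, 0]),
    ([1, 1, 1], [1, 1, 1]),
]
metric = StreamingMacroRecall(num_classes=2)
per_batch = []
for y_true, y_pred in batches:
    metric.update(y_true, y_pred)
    single = StreamingMacroRecall(num_classes=2)
    single.update(y_true, y_pred)
    per_batch.append(single.result())

print("mean of per-batch recalls:", np.mean(per_batch))
print("recall from accumulated counts:", metric.result())
```

If the two numbers diverge noticeably for your batch size, accumulating true-positive and label counts per class as weights in the metric, and dividing only in `result()`, may be the safer design.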

And the checkpoint used during model training:

callbacks.ModelCheckpoint(
    filepath=os.path.join(
        self._metadata['checkpoint_directory'],
        f'checkpoint-{self._metadata["create_time"]}.h5' ),
    # Only keep the best model when validation data is available
    save_best_only=bool( self._val ),
    # 'val_' + the metric's name as set in MacroAverageRecall
    monitor='val_multiclass_recall',
    mode='max',
    verbose=1 )
cian