
I am implementing a CNN for a highly unbalanced classification problem, and I would like to implement custom metrics in TensorFlow so that I can use the "select best model" callback. Specifically, I would like to implement the balanced accuracy score, which is the average of the recall of each class (see the sklearn implementation here). Does anyone know how to do it?
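For reference, this is the quantity I mean, computed with sklearn on toy labels (balanced accuracy is just the macro-averaged recall):

from sklearn.metrics import balanced_accuracy_score, recall_score

y_true = [0, 0, 0, 0, 0, 0, 1, 1]  # toy, heavily unbalanced labels
y_pred = [0, 0, 0, 0, 0, 1, 1, 0]

print(balanced_accuracy_score(y_true, y_pred))        # 0.666...
print(recall_score(y_true, y_pred, average='macro'))  # same value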

desertnaut

5 Answers


I was facing the same issue, so I implemented a custom class based on SparseCategoricalAccuracy:

import tensorflow as tf
from tensorflow import keras


class BalancedSparseCategoricalAccuracy(keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, name='balanced_sparse_categorical_accuracy', dtype=None):
        super().__init__(name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_flat = y_true
        if y_true.shape.ndims == y_pred.shape.ndims:
            y_flat = tf.squeeze(y_flat, axis=[-1])
        y_true_int = tf.cast(y_flat, tf.int32)

        # Weight every sample by the inverse of its class frequency in the batch,
        # so that each class contributes equally to the accuracy.
        cls_counts = tf.math.bincount(y_true_int)
        cls_counts = tf.math.reciprocal_no_nan(tf.cast(cls_counts, self.dtype))
        weight = tf.gather(cls_counts, y_true_int)
        return super().update_state(y_true, y_pred, sample_weight=weight)

The idea is to weight each sample inversely proportionally to the size of its class.

This code produces some warnings from AutoGraph, but I believe those are AutoGraph bugs, and the metric seems to work fine.
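A possible way to plug the metric into training and a best-model checkpoint might look like this (model, x_train, y_train, x_val, y_val, and the file name are placeholders):

# Hypothetical usage: monitor the metric with ModelCheckpoint to keep the best model.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[BalancedSparseCategoricalAccuracy()])

checkpoint = keras.callbacks.ModelCheckpoint(
    'best_model.h5',
    monitor='val_balanced_sparse_categorical_accuracy',
    mode='max',
    save_best_only=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=20,
          callbacks=[checkpoint])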

Aaron Keesing

There are three ways I can think of to tackle this situation:

1) Random under-sampling: randomly remove samples from the majority classes.

2) Random over-sampling: increase the number of minority-class samples by replicating them.

3) Weighted cross-entropy: use a weighted cross-entropy loss so that errors on the minority classes are penalized more heavily. See here, and the sketch after this list.

I have personally tried method 2, and it increased my accuracy by a significant amount, but the effect may vary from dataset to dataset.
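As a rough sketch of option 3, one common way to get a class-weighted cross-entropy in Keras is the class_weight argument of fit; model, x_train and y_train are placeholders, and the weights below are just the usual inverse-frequency heuristic:

import numpy as np

# Weight each class inversely to its frequency (sklearn's "balanced" heuristic),
# so the loss compensates for the minority classes.
counts = np.bincount(y_train)  # y_train: integer class labels (placeholder)
class_weight = {i: len(y_train) / (len(counts) * c) for i, c in enumerate(counts)}

model.fit(x_train, y_train, epochs=10, class_weight=class_weight)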


NOTE

It appears that the implementation/API of the Recall class, which I used as a template for my answer, has been modified in newer TF versions (as pointed out by @guilaumme-gaudin). I therefore recommend that you look at the Recall implementation in your current TF version and adapt the approach I describe in the original post from there; that way I don't have to update my answer every time the TF team modifies the implementation/API of its metrics.

Original post

I'm not an expert in TensorFlow, but by pattern-matching between the metric implementations in the TF source code I came up with this:

import numpy as np

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.metrics import Metric
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.keras.utils.generic_utils import to_list

class BACC(Metric):

    def __init__(self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None):
        super(BACC, self).__init__(name=name, dtype=dtype)
        self.init_thresholds = thresholds
        self.top_k = top_k
        self.class_id = class_id

        default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
        self.thresholds = metrics_utils.parse_init_thresholds(
            thresholds, default_threshold=default_threshold)
        self.true_positives = self.add_weight(
            'true_positives',
            shape=(len(self.thresholds),),
            initializer=init_ops.zeros_initializer)
        self.true_negatives = self.add_weight(
            'true_negatives',
            shape=(len(self.thresholds),),
            initializer=init_ops.zeros_initializer)
        self.false_positives = self.add_weight(
            'false_positives',
            shape=(len(self.thresholds),),
            initializer=init_ops.zeros_initializer)
        self.false_negatives = self.add_weight(
            'false_negatives',
            shape=(len(self.thresholds),),
            initializer=init_ops.zeros_initializer)

    def update_state(self, y_true, y_pred, sample_weight=None):

        return metrics_utils.update_confusion_matrix_variables(
            {
                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
            },
            y_true,
            y_pred,
            thresholds=self.thresholds,
            top_k=self.top_k,
            class_id=self.class_id,
            sample_weight=sample_weight)

    def result(self):
        """
        Returns the Balanced Accuracy (average between recall and specificity)
        """
        result = (math_ops.div_no_nan(self.true_positives, self.true_positives + self.false_negatives) +
                  math_ops.div_no_nan(self.true_negatives, self.true_negatives + self.false_positives)) / 2
        
        return result

    def reset_states(self):
        num_thresholds = len(to_list(self.thresholds))
        K.batch_set_value(
            [(v, np.zeros((num_thresholds,))) for v in self.variables])

    def get_config(self):
        config = {
            'thresholds': self.init_thresholds,
            'top_k': self.top_k,
            'class_id': self.class_id
        }
        base_config = super(BACC, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

I simply took the Recall class implementation from the source code as a template and extended it to make sure it has TP, TN, FP, and FN defined.

After that I modified the result method so that it calculates balanced accuracy, and voilà :)

I compared the results from this with sklearn's balanced accuracy score and the values matched, so I think it's correct, but do double-check just in case.
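One way to run that sanity check yourself (toy binary labels and scores below, purely illustrative, with the default 0.5 threshold assumed):

import numpy as np
from sklearn.metrics import balanced_accuracy_score

y_true = np.array([0, 0, 0, 0, 1, 1, 0, 1], dtype=np.float32)
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.3, 0.9], dtype=np.float32)

m = BACC()
m.update_state(y_true, y_prob)
print(m.result().numpy())  # e.g. [0.7333]

# sklearn reference value on the same thresholded predictions
print(balanced_accuracy_score(y_true.astype(int), (y_prob >= 0.5).astype(int)))  # e.g. 0.7333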

Gabriel Cia
    There is an error in the code: in `K.batch_set_value([(v, np.zeros((num_thresholds,))) for v in self.variables])` you must replace `self.variables` with `(self.true_positives, self.true_negatives, self.false_negatives, self.false_positives)`, as in the Recall metric source (https://github.com/keras-team/keras/blob/v2.8.0/keras/metrics.py#L1439-L1565). Otherwise you don't compute the correct values. – guillaume godin Apr 30 '22 at 04:43
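Applied to the class above, the change the comment describes would replace the reset_states method with something like:

    def reset_states(self):
        num_thresholds = len(to_list(self.thresholds))
        # Reset only the four confusion-matrix accumulators, as in the Recall source.
        K.batch_set_value(
            [(v, np.zeros((num_thresholds,)))
             for v in (self.true_positives, self.true_negatives,
                       self.false_negatives, self.false_positives)])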

I have not tested this code yet, but looking at the source code of tensorflow==2.1.0, this might work for the binary classification case:

from tensorflow.keras.metrics import Recall
from tensorflow.python.ops import math_ops


class BalancedBinaryAccuracy(Recall):
    def result(self):
        result = (math_ops.div_no_nan(self.true_positives, self.true_positives + self.false_negatives) +
                  math_ops.div_no_nan(self.true_negatives, self.true_negatives + self.false_positives)) / 2
        return result[0] if len(self.thresholds) == 1 else result

Alexandre Huat

As an alternative to writing a custom metric, you can write a custom callback that uses metrics already implemented and available via the training logs. For example, you can define a training balanced accuracy callback like this:

import tensorflow as tf

# Note: these callbacks assume the model is compiled with confusion-matrix
# metrics named 'tp', 'fn', 'tn' and 'fp' (see the sketch at the end).


class TrainBalancedAccuracyCallback(tf.keras.callbacks.Callback):

    def __init__(self, **kargs):
        super(TrainBalancedAccuracyCallback, self).__init__(**kargs)

    def on_epoch_end(self, epoch, logs={}):
        # Sensitivity (recall) and specificity from the logged confusion-matrix counts.
        train_sensitivity = logs['tp'] / (logs['tp'] + logs['fn'])
        train_specificity = logs['tn'] / (logs['tn'] + logs['fp'])
        logs['train_sensitivity'] = train_sensitivity
        logs['train_specificity'] = train_specificity
        logs['train_balacc'] = (train_sensitivity + train_specificity) / 2
        print('train_balacc', logs['train_balacc'])

and the same for the validation:

class ValBalancedAccuracyCallback(tf.keras.callbacks.Callback):

    def __init__(self, **kargs):
        super(ValBalancedAccuracyCallback, self).__init__(**kargs)

    def on_epoch_end(self, epoch, logs={}):

        val_sensitivity = logs['val_tp'] / (logs['val_tp'] + logs['val_fn'])
        val_specificity = logs['val_tn'] / (logs['val_tn'] + logs['val_fp'])
        logs['val_sensitivity'] = val_sensitivity
        logs['val_specificity'] = val_specificity
        logs['val_balacc'] = (val_sensitivity + val_specificity) / 2
        print('val_balacc', logs['val_balacc'])

You can then pass instances of these callbacks to the callbacks argument of the model's fit method.
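For the logs to contain those keys, the model has to be compiled with confusion-matrix metrics named 'tp', 'tn', 'fp' and 'fn'. A rough sketch of the wiring (model and data variables are placeholders):

import tensorflow as tf

# Log the confusion-matrix counts under the names the callbacks expect.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[tf.keras.metrics.TruePositives(name='tp'),
             tf.keras.metrics.TrueNegatives(name='tn'),
             tf.keras.metrics.FalsePositives(name='fp'),
             tf.keras.metrics.FalseNegatives(name='fn')])

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=20,
          callbacks=[TrainBalancedAccuracyCallback(), ValBalancedAccuracyCallback()])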

Aelius