
I am building a custom metric to measure the accuracy of one class in my multi-class dataset during training. I am having trouble selecting the class.

The targets are one-hot encoded (e.g., the class 0 label is [1 0 0 0 0]):

from keras import backend as K

def single_class_accuracy(y_true, y_pred):
    idx = bool(y_true[:, 0])              # boolean mask for class 0 
    class_preds = y_pred[idx]
    class_true = y_true[idx]
    class_acc = K.mean(K.equal(K.argmax(class_true, axis=-1), K.argmax(class_preds, axis=-1)))  # multi-class accuracy  
    return class_acc

The trouble is, we have to use Keras functions to index tensors. How do you create a boolean mask for a tensor?

nbro
Chris Parry
  • I'm not familiar with Keras and do not know if your code will work with boolean masks or explicit indices. Did you cast your mask to type boolean? tf.cast(binary_mask, tf.bool). With Theano you can use bool_mask.nonzero() to get the indices of the boolean mask. Let us know if this solution works. – rafaelvalle Jan 16 '17 at 23:22
  • Would you accept the answer which is using a callback? – Marcin Możejko Jan 17 '17 at 13:48
  • Just to make sure: is y_true 2D? What are the rows and columns supposed to represent here? – ginge Jan 17 '17 at 15:32
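Building on the comment about tf.bool masks: if the backend is TensorFlow 2 (an assumption; the question uses the standalone keras.backend), the mask the question asks about can be built with a plain comparison, which already yields a tf.bool tensor, and applied with tf.boolean_mask, which keeps only the rows where the mask is True and also works on symbolic tensors inside a compiled metric. A minimal sketch that keeps the question's approach; masked_class_accuracy is a hypothetical name, and because it selects rows by their true label it computes the per-class recall:

```python
import tensorflow as tf

def masked_class_accuracy(y_true, y_pred, class_id=0):
    # Boolean mask of the rows whose one-hot label is `class_id`.
    # The comparison already returns a tf.bool tensor, so no cast is needed.
    mask = y_true[:, class_id] > 0
    # Keep only the masked rows (works for symbolic and eager tensors).
    class_true = tf.boolean_mask(y_true, mask)
    class_preds = tf.boolean_mask(y_pred, mask)
    correct = tf.cast(
        tf.equal(tf.argmax(class_true, axis=-1), tf.argmax(class_preds, axis=-1)),
        tf.float32,
    )
    # Guard against batches with no example of the class:
    # the result is 0.0 instead of NaN (0 / 0).
    n = tf.cast(tf.size(correct), tf.float32)
    return tf.reduce_sum(correct) / tf.maximum(n, 1.0)
```

With class_id defaulting to 0, the function keeps the (y_true, y_pred) signature Keras expects from a metric.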

2 Answers


Note that when talking about the accuracy of one class, one may refer to either of the following two (non-equivalent) quantities:

  • The recall, which, for class C, is the fraction of examples labelled with class C that are predicted as class C.
  • The precision, which, for class C, is the fraction of examples predicted as class C that are in fact labelled with class C.
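A small worked example (made-up numbers, NumPy only) shows that the two quantities can differ for the same predictions:

```python
import numpy as np

# Hypothetical toy batch: one-hot labels and model scores for 3 classes.
y_true = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_pred = np.array([[0.8, 0.1, 0.1],
                   [0.3, 0.6, 0.1],
                   [0.7, 0.2, 0.1],
                   [0.5, 0.1, 0.4],
                   [0.6, 0.3, 0.1]])

C = 0
labels = y_true.argmax(axis=-1)  # [0, 0, 1, 2, 0]
preds = y_pred.argmax(axis=-1)   # [0, 1, 0, 0, 0]

# Examples that are labelled C AND predicted C.
true_positives = np.sum((labels == C) & (preds == C))  # 2

recall = true_positives / np.sum(labels == C)     # 2 / 3: divide by labelled-C count
precision = true_positives / np.sum(preds == C)   # 2 / 4: divide by predicted-C count
```

Both share the same numerator; only the denominator changes, which is why switching between them in the code is a one-line change.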

Instead of doing complex indexing, you can just rely on masking for your computation. The following assumes we are talking about precision (changing it to recall is trivial).

from keras import backend as K

INTERESTING_CLASS_ID = 0  # Choose the class of interest

def single_class_accuracy(y_true, y_pred):
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Replace class_id_preds with class_id_true for recall here
    accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc
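The masking trick can be checked outside Keras. Below, masked_precision (a hypothetical helper) repeats the arithmetic of the function above in NumPy, one step per K.* call, and indexed_precision (also hypothetical) computes the same quantity with the boolean indexing the question was attempting; on any batch the two should agree. This is a sketch for intuition only, not a replacement for the backend version.

```python
import numpy as np

def masked_precision(y_true, y_pred, class_id):
    class_id_true = y_true.argmax(axis=-1)
    class_id_preds = y_pred.argmax(axis=-1)
    # 1 where the model predicted `class_id`, else 0 (the "accuracy_mask").
    accuracy_mask = (class_id_preds == class_id).astype(np.int32)
    # 1 only where the prediction is correct AND inside the mask.
    class_acc_tensor = (class_id_true == class_id_preds).astype(np.int32) * accuracy_mask
    # max(..., 1) avoids dividing by zero when the class is never predicted.
    return class_acc_tensor.sum() / max(accuracy_mask.sum(), 1)

def indexed_precision(y_true, y_pred, class_id):
    # Same quantity via boolean indexing: easy on NumPy arrays,
    # which is what the question tried to do on Keras tensors.
    selected = y_pred.argmax(axis=-1) == class_id
    if not selected.any():
        return 0.0
    return float(np.mean(y_true.argmax(axis=-1)[selected] == class_id))
```

Multiplying by a 0/1 mask and dividing by the mask's sum is the same as averaging over only the selected rows, which is why no indexing is needed.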

If you want to be more flexible, you can also have the class of interest parametrised:

from keras import backend as K

def single_class_accuracy(interesting_class_id):
    def fn(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Replace class_id_preds with class_id_true for recall here
        accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
        return class_acc
    return fn

And then use it as:

model.compile(..., metrics=[single_class_accuracy(INTERESTING_CLASS_ID)])
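If you track more than one class, every closure returned by the factory is named fn, which makes the logged metrics hard to tell apart. A possible refinement, assuming TensorFlow 2 and assuming Keras takes the metric name from the function's __name__ (as tf.keras does for plain-function metrics), is to rename each closure. The sketch below also adds a hypothetical mode argument to switch between precision and recall, and works in float32 throughout:

```python
import tensorflow as tf

def single_class_metric(class_id, mode="precision"):
    """Hypothetical variant of the factory above: per-class precision or
    recall, with a distinct name for each returned closure."""
    if mode not in ("precision", "recall"):
        raise ValueError("mode must be 'precision' or 'recall'")

    def fn(y_true, y_pred):
        class_id_true = tf.argmax(y_true, axis=-1)
        class_id_preds = tf.argmax(y_pred, axis=-1)
        # Precision conditions on the prediction, recall on the label.
        reference = class_id_preds if mode == "precision" else class_id_true
        mask = tf.cast(tf.equal(reference, class_id), tf.float32)
        correct = tf.cast(tf.equal(class_id_true, class_id_preds), tf.float32)
        return tf.reduce_sum(correct * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)

    # Give each closure its own name, e.g. "class_0_precision".
    fn.__name__ = f"class_{class_id}_{mode}"
    return fn
```

For example, metrics=[single_class_metric(0), single_class_metric(0, mode="recall")] in model.compile would then report the two values under separate names.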
shora
jdehesa

Just adding an updated version of the code from jdehesa's answer that works with the newer TensorFlow API (tf.math). Please upvote that answer as well if you find this helpful.

import tensorflow as tf

INTERESTING_CLASS_ID = 0  # Choose the class of interest

def single_class_accuracy(y_true, y_pred):
    class_id_true = tf.math.argmax(y_true, axis=-1)
    class_id_preds = tf.math.argmax(y_pred, axis=-1)

    # Replace class_id_preds with class_id_true for recall here
    accuracy_mask = tf.cast(tf.math.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = tf.cast(tf.math.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = tf.math.reduce_sum(class_acc_tensor) / tf.math.maximum(tf.math.reduce_sum(accuracy_mask), 1)

    return class_acc
Waylon Flinn