I am trying to calculate the recall for each class after each epoch, in both binary and multi-class (one-hot encoded) classification scenarios, for a model that uses TensorFlow 2's Keras API. For example, for binary classification I'd like to be able to do something like
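import matplotlib.pyplot as plt
import tensorflow as tf

model = tf.keras.Sequential()
model.add(...)
# A single sigmoid unit outputs the probability of class 1.
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
# binary_recall is the metric I'm looking for; Keras does not provide it.
model.compile(..., metrics=[binary_recall(label=0), binary_recall(label=1)])
history = model.fit(...)
plt.plot(history.history['binary_recall_0'])
plt.plot(history.history['binary_recall_1'])
plt.show()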
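A minimal sketch of what such a `binary_recall(label=...)` metric could look like, written as a stateful `tf.keras.metrics.Metric` subclass. `BinaryRecall` is a hypothetical name; the sketch assumes a single sigmoid output, a 0.5 decision threshold, and it ignores `sample_weight`. Because it accumulates counts across batches (and Keras resets metric state at the start of each epoch), the value reported for an epoch is the recall over that whole epoch rather than an average of batch values. For `label=0` this is what is usually called specificity (true negative rate).

```python
import tensorflow as tf


class BinaryRecall(tf.keras.metrics.Metric):
    """Hypothetical recall metric for one label of a single-sigmoid-output model.

    Assumes y_pred holds probabilities of class 1; sample_weight is ignored.
    """

    def __init__(self, label=1, threshold=0.5, name=None, **kwargs):
        super().__init__(name=name or f"binary_recall_{label}", **kwargs)
        self.label = label
        self.threshold = threshold
        # Running counts, accumulated over all batches of an epoch.
        self.true_positives = self.add_weight(name="tp", initializer="zeros")
        self.false_negatives = self.add_weight(name="fn", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
        # Hard 0/1 predictions from the probability of class 1.
        predicted = tf.cast(y_pred > self.threshold, tf.float32)
        # 1.0 for samples whose true label is the label of interest.
        is_label = tf.cast(tf.equal(y_true, float(self.label)), tf.float32)
        # 1.0 for those samples that were also predicted as that label.
        hits = tf.cast(tf.equal(predicted, float(self.label)), tf.float32) * is_label
        self.true_positives.assign_add(tf.reduce_sum(hits))
        self.false_negatives.assign_add(tf.reduce_sum(is_label - hits))

    def result(self):
        # divide_no_nan returns 0.0 when no sample of the label has been seen.
        return tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )

    def reset_state(self):
        self.true_positives.assign(0.0)
        self.false_negatives.assign(0.0)
```

Used as `model.compile(..., metrics=[BinaryRecall(label=0), BinaryRecall(label=1)])`, the metric names default to `binary_recall_0` and `binary_recall_1`, so they should appear under those keys in `history.history` (with `val_` prefixes when validation data is passed to `fit`).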
Or, in a multi-class scenario, I'd like to do something like
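import matplotlib.pyplot as plt
import tensorflow as tf

model = tf.keras.Sequential()
model.add(...)
# Softmax turns the 3 outputs into class probabilities.
model.add(tf.keras.layers.Dense(3, activation='softmax'))
# recall is the per-class metric I'm looking for; Keras does not provide it.
model.compile(..., metrics=[recall(label=0), recall(label=1), recall(label=2)])
history = model.fit(...)
plt.plot(history.history['recall_0'])
plt.plot(history.history['recall_1'])
plt.plot(history.history['recall_2'])
plt.show()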
I'm working on a classifier for an imbalanced dataset and want to see at what point the recall for my minority class(es) starts to degrade.
I found an implementation of precision for a specific class in a multi-class classifier here: https://stackoverflow.com/a/41717938/373655. I am trying to adapt it to what I need, but keras.backend is still pretty foreign to me, so any help would be greatly appreciated.
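One way to adapt that kind of per-class function to recall, sketched with plain TensorFlow ops (`tf.argmax`, `tf.equal`, `tf.reduce_sum`) instead of `keras.backend`. `class_recall` is a hypothetical helper; it assumes one-hot labels and treats the argmax of the model output as the predicted class. As a plain function metric it is computed batch by batch and Keras averages those batch values, so it is subject to the batch-averaging concern raised below.

```python
import tensorflow as tf


def class_recall(class_id):
    """Hypothetical factory returning a stateless recall metric for one class.

    Assumes one-hot y_true and treats argmax(y_pred) as the predicted class.
    """
    target = tf.constant(class_id, dtype=tf.int64)

    def recall(y_true, y_pred):
        true_cls = tf.argmax(y_true, axis=-1)  # int64 class indices
        pred_cls = tf.argmax(y_pred, axis=-1)
        # 1.0 for samples whose true class is class_id, else 0.0.
        is_class = tf.cast(tf.equal(true_cls, target), tf.float32)
        # 1.0 for those samples that were also predicted as class_id.
        hits = tf.cast(tf.equal(pred_cls, target), tf.float32) * is_class
        # divide_no_nan gives 0.0 when the batch has no samples of class_id.
        return tf.math.divide_no_nan(tf.reduce_sum(hits), tf.reduce_sum(is_class))

    # Keras uses a function metric's __name__ as its name, e.g. "recall_0".
    recall.__name__ = f"recall_{class_id}"
    return recall
```

It would be passed as `model.compile(..., metrics=[class_recall(0), class_recall(1), class_recall(2)])`.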
I am also not clear on whether I can use Keras metrics (which are calculated at the end of each batch and then averaged) or whether I need to use Keras callbacks (which can run at the end of each epoch). It seems to me that it shouldn't make a difference for recall (e.g. 8/10 == (3/5 + 5/5) / 2), but this is why recall was removed in Keras 2, so maybe I'm missing something (https://github.com/keras-team/keras/issues/5794).
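The equality in that example holds because both batches contain the same number of positives (5 each). A small sketch with made-up counts (the helper names are hypothetical) shows that the unweighted mean of per-batch recall and the recall over the pooled epoch only agree in that case; a batch with no positives at all would even make its per-batch recall undefined. For comparison, the built-in `tf.keras.metrics.Recall` is stateful: it accumulates true positives and false negatives across batches and is reset each epoch, so the value it reports for an epoch corresponds to the pooled version.

```python
def pooled_recall(batches):
    """Recall over all batches combined: total TP / total actual positives."""
    tp = sum(t for t, _ in batches)
    positives = sum(p for _, p in batches)
    return tp / positives


def mean_batch_recall(batches):
    """Unweighted mean of per-batch recalls (what averaging batch values gives)."""
    return sum(t / p for t, p in batches) / len(batches)


# Each tuple is (true positives, actual positives) for one batch.
equal_sizes = [(3, 5), (5, 5)]    # the example from the question
unequal_sizes = [(1, 1), (0, 9)]  # same 10 positives in total, split unevenly

print(pooled_recall(equal_sizes), mean_batch_recall(equal_sizes))      # 0.8 0.8
print(pooled_recall(unequal_sizes), mean_batch_recall(unequal_sizes))  # 0.1 0.5
```

In the uneven case only 1 of 10 positives is found, yet the batch average reports 0.5, which is exactly the kind of distortion that matters for a minority class.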
Edit: partial solution (multi-class classification)

@mujjiga's solution works for both binary and multi-class classification, but as @P-Gn pointed out, TensorFlow 2's Recall metric supports this out of the box for multi-class classification through its class_id argument. For example:
import matplotlib.pyplot as plt
from tensorflow.keras.metrics import Recall

model = ...
model.compile(loss='categorical_crossentropy', metrics=[
    Recall(class_id=0, name='recall_0'),
    Recall(class_id=1, name='recall_1'),
    Recall(class_id=2, name='recall_2'),
])
history = model.fit(...)
# Training and validation recall for class 2 (val_recall_2 requires validation data in fit)
plt.plot(history.history['recall_2'])
plt.plot(history.history['val_recall_2'])
plt.show()
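One detail worth checking with `Recall(class_id=...)` on softmax outputs: by default a class only counts as predicted when its score exceeds the 0.5 threshold, so a correct argmax prediction with probability below 0.5 is scored as a miss. The `top_k` argument of `tf.keras.metrics.Recall` changes that; with `top_k=1` the class counts as predicted when it is the highest-scoring one. The toy labels and probabilities below are made up for illustration.

```python
import tensorflow as tf

# Two samples, both truly class 1 (one-hot), with made-up softmax outputs.
y_true = tf.constant([[0.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.30, 0.45, 0.25],   # argmax is class 1, but 0.45 < 0.5
                      [0.10, 0.80, 0.10]])

# Default behaviour: class 1 must score above the 0.5 threshold.
recall_threshold = tf.keras.metrics.Recall(class_id=1, name="recall_1")
recall_threshold.update_state(y_true, y_pred)

# top_k=1: class 1 counts as predicted whenever it is the top-scoring class.
recall_top1 = tf.keras.metrics.Recall(class_id=1, top_k=1, name="recall_1_top1")
recall_top1.update_state(y_true, y_pred)

print(float(recall_threshold.result()))  # 0.5
print(float(recall_top1.result()))       # 1.0
```

In the compile call above that would mean `Recall(class_id=0, top_k=1, name='recall_0')` and so on; which behaviour is right depends on how the model's outputs are turned into class decisions.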