I would like to implement a model checkpoint callback based on the balanced accuracy score. For this, I implemented the following class:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import balanced_accuracy_score

class BalAccScore(keras.callbacks.Callback):

    def __init__(self, validation_data=None):
        super(BalAccScore, self).__init__()
        self.validation_data = validation_data

    def on_train_begin(self, logs={}):
        self.balanced_accuracy = []

    def on_epoch_end(self, epoch, logs={}):
        # predict on the validation set and compute the balanced accuracy
        y_predict = tf.argmax(self.model.predict(self.validation_data[0]), axis=1)
        y_true = tf.argmax(self.validation_data[1], axis=1)
        balacc = balanced_accuracy_score(y_true, y_predict)
        self.balanced_accuracy.append(round(balacc, 6))
        # expose the score to other callbacks (e.g. ModelCheckpoint) via the logs dict
        logs["val_bal_acc"] = balacc
        print("\n ------ validation balanced accuracy score: %f ------\n" % balacc)
I then define the following callbacks:
balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath=callback_path, monitor="val_bal_acc", verbose=1, save_best_only=True, save_freq='epoch')
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['val_bal_acc'])
history = model.fit(X_1, y_1, epochs=5, batch_size=512,
                    callbacks=[balAccScore, mc],
                    validation_data=(X_2, y_2))
I then get the error:
ValueError: Unknown metric function: val_bal_acc
even though the metric does show up under history when I use a built-in metric instead, e.g. by setting metrics=["acc"] when compiling. In that case I get the expected warning:
WARNING:tensorflow:Can save best model only with val_bal_acc available, skipping.
but otherwise the model runs perfectly. I am not sure why it does not run with val_bal_acc as the metric.
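For reference, this is the variant that runs (a minimal sketch; only the metrics argument differs from the failing compile call above):

# same model, data and callbacks as above; only the metrics argument changes
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["acc"])

With this compile call, the val_bal_acc values written by the callback do appear under history, and only the ModelCheckpoint warning shown above is printed.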