
I would like to implement a model checkpoint callback based on the balanced accuracy score. For this, I implemented the following class:

import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import balanced_accuracy_score

class BalAccScore(keras.callbacks.Callback):

    def __init__(self, validation_data=None):
        super(BalAccScore, self).__init__()
        self.validation_data = validation_data

    def on_train_begin(self, logs={}):
        self.balanced_accuracy = []

    def on_epoch_end(self, epoch, logs={}):
        # Predict on the held-out validation set and compare class indices.
        y_predict = tf.argmax(self.model.predict(self.validation_data[0]), axis=1)
        y_true = tf.argmax(self.validation_data[1], axis=1)
        balacc = balanced_accuracy_score(y_true, y_predict)
        self.balanced_accuracy.append(round(balacc, 6))
        # Write the score into logs so that other callbacks can monitor it.
        logs["val_bal_acc"] = balacc
        keys = list(logs.keys())

        print("\n ------ validation balanced accuracy score: %f ------\n" % balacc)

I then define the following callbacks, compile, and fit the model:

from tensorflow.keras.callbacks import ModelCheckpoint

balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath=callback_path, monitor="val_bal_acc", verbose=1,
                     save_best_only=True, save_freq='epoch')

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=['val_bal_acc'])

history = model.fit(X_1, y_1, epochs=5, batch_size=512,
                    callbacks=[balAccScore, mc],
                    validation_data=(X_2, y_2))

I then get the error

ValueError: Unknown metric function: val_bal_acc

despite the fact that it does show up under history when I use, for example, accuracy instead, i.e. by setting metrics=["acc"] when compiling. In that case I get the expected warning:

WARNING:tensorflow:Can save best model only with val_bal_acc available, skipping.

but otherwise the model runs fine. I am not sure why it does not run with my custom metric.

Strickland

2 Answers


You should just remove the quotation marks in compile:

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=[val_bal_acc])

or at least this is how it works in R.

Doha Naga

You're getting that error because you're not passing a valid metric to the metrics argument when compiling the model. The string 'val_bal_acc' doesn't work because it isn't a known metric: string names only work for metrics that are already implemented in tf.keras.metrics.

If you want to monitor the validation balanced accuracy during training, you should implement a custom metric class (you can look here for how to do it) and pass it to the metrics argument; a sketch is included below. Once you've done this, you can monitor the metric under the name you gave it, prefixed with 'val_' for the value computed on the validation data. There's no need for a supplementary custom callback as you did; the logs are updated automatically once the metric is defined. For this particular case, you can find implementations of this metric in the answers to this question.
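
For illustration, here is a minimal sketch of such a custom metric, assuming one-hot encoded labels and a known number of classes; the class name BalancedAccuracy and its constructor arguments are my own choices, not part of Keras:

import tensorflow as tf

class BalancedAccuracy(tf.keras.metrics.Metric):
    """Balanced accuracy = mean of the per-class recalls."""

    def __init__(self, num_classes, name="bal_acc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Running per-class counts of correct predictions and of true samples.
        self.correct = self.add_weight(name="correct", shape=(num_classes,), initializer="zeros")
        self.total = self.add_weight(name="total", shape=(num_classes,), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch; labels assumed one-hot.
        y_true_lbl = tf.argmax(y_true, axis=-1)
        y_pred_lbl = tf.argmax(y_pred, axis=-1)
        cm = tf.math.confusion_matrix(y_true_lbl, y_pred_lbl,
                                      num_classes=self.num_classes, dtype=tf.float32)
        self.correct.assign_add(tf.linalg.diag_part(cm))
        self.total.assign_add(tf.reduce_sum(cm, axis=1))

    def result(self):
        recalls = tf.math.divide_no_nan(self.correct, self.total)
        return tf.reduce_mean(recalls)

    def reset_state(self):  # reset_states() on older TF versions
        self.correct.assign(tf.zeros(self.num_classes))
        self.total.assign(tf.zeros(self.num_classes))

You would then compile with something like metrics=[BalancedAccuracy(num_classes=y_1.shape[1])] and point ModelCheckpoint at monitor="val_bal_acc".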

If you prefer a callback approach instead, you don't need to define a custom metric; you can take advantage of the metrics that are already logged. You can find an implementation of that in my answer here.
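
If you go that route with the BalAccScore callback from your question, a minimal sketch of the wiring could look like the following (the checkpoint path is a placeholder). The key points are to drop 'val_bal_acc' from metrics in compile, and to list BalAccScore before ModelCheckpoint so that logs["val_bal_acc"] is already set when the checkpoint callback reads it (in recent TF versions the logs dict is shared between callbacks, so order matters):

from tensorflow.keras.callbacks import ModelCheckpoint

balAccScore = BalAccScore(validation_data=(X_2, y_2))
mc = ModelCheckpoint(filepath="best_model.h5",   # placeholder path
                     monitor="val_bal_acc",      # key written into logs by BalAccScore
                     mode="max",                 # higher balanced accuracy is better
                     verbose=1, save_best_only=True, save_freq="epoch")

# No 'val_bal_acc' here: it is not a compiled metric, only a logged value.
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(X_1, y_1, epochs=5, batch_size=512,
                    callbacks=[balAccScore, mc],  # BalAccScore first, then the checkpoint
                    validation_data=(X_2, y_2))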

Aelius