
The tf.keras.Model I am training has the following primary performance indicators:

  • escape rate: (#samples with predicted label 0 AND true label 1) / (#samples with true label 1)
  • false call rate: (#samples with predicted label 1 AND true label 0) / (#samples with true label 0)
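
In NumPy terms (with made-up toy labels and hard 0/1 predictions), that is:

import numpy as np

y_true = np.array([1, 1, 0, 0, 1, 0])   # ground truth: 1 = NOK, 0 = OK
y_pred = np.array([0, 1, 1, 0, 1, 0])   # predicted labels

# escape rate: NOK samples that were predicted OK
escape_rate = np.sum((y_pred == 0) & (y_true == 1)) / np.sum(y_true == 1)

# false call rate: OK samples that were predicted NOK
false_call_rate = np.sum((y_pred == 1) & (y_true == 0)) / np.sum(y_true == 0)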

The targeted escape rate is predefined, which means the decision threshold will have to be set appropriately. To calculate the resulting false call rate, I would like to implement a custom metric something along the lines of the following pseudo code:

# separate predicted probabilities by their true label
all_ok_probabilities = all_probabilities.filter(true_label == 0)
all_nok_probabilities = all_probabilities.filter(true_label == 1)

# sort NOK samples
sorted_nok_probabilities = all_nok_probabilities.sort(ascending)

# determine decision threshold
threshold_idx = round(target_escape_rate * num_nok_samples) - 1
threshold = sorted_nok_probabilities[threshold_idx]

# calculate false call rate
false_calls = count(all_ok_probabilities > threshold)
false_call_rate = false_calls / num_ok_samples
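
In plain NumPy (ignoring distribution for a moment), the calculation I am after looks roughly like this; the function name is only for illustration, and the probabilities are assumed to be the predicted probability of class 1 (NOK):

import numpy as np

def false_call_rate_at_escape_rate(all_probabilities, true_labels, target_escape_rate):
    # separate predicted probabilities by their true label
    ok_probs = all_probabilities[true_labels == 0]
    nok_probs = all_probabilities[true_labels == 1]

    # sort NOK samples ascending
    sorted_nok = np.sort(nok_probs)

    # decision threshold: chosen so that (at most) the target fraction
    # of NOK samples falls below it and thus escapes
    threshold_idx = int(round(target_escape_rate * sorted_nok.shape[0])) - 1
    threshold = sorted_nok[max(threshold_idx, 0)]   # guard against a negative index

    # false call rate: OK samples whose probability exceeds the threshold
    false_calls = np.count_nonzero(ok_probs > threshold)
    return false_calls / ok_probs.shape[0]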

My issue is that, in a MirroredStrategy environment, tf.keras automatically distributes metric calculation across all replicas, each of them getting (batch_size / n_replicas) samples per update, and finally sums the results. My algorithm, however, only works correctly if ALL labels & predictions are combined (the final summing could probably be compensated for by dividing by the number of replicas).

My idea is to concatenate all y_true and y_pred in my metric's update_state() method into sequences and run the evaluation in result(). The first step already seems impossible, however; tf.Variable only provides suitable aggregation methods for numeric scalars, not for sequences: tf.VariableAggregation.ONLY_FIRST_REPLICA makes me lose all data from the 2nd to the nth replica, SUM silently locks up the fit() call, and MEAN does not make any sense in my application (and might hang just as well).
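
For illustration, here is a single-replica sketch of that idea (class name and details are hypothetical; it assumes the model outputs two class probabilities with class 1 = NOK). Without a distribution strategy this does what I want; under MirroredStrategy each replica only accumulates its own share of the data, which is exactly the problem described above:

import tensorflow as tf

class FalseCallRateMetric(tf.keras.metrics.Metric):
    """accumulate all labels & probabilities, evaluate them in result()"""

    def __init__(self, target_escape_rate, name="false_call_rate", **kwargs):
        super().__init__(name=name, **kwargs)
        self.target_escape_rate = target_escape_rate
        # unspecified shape, so the variables can grow with every update
        self.probs = tf.Variable(tf.zeros([0]), shape=tf.TensorShape(None), trainable=False)
        self.labels = tf.Variable(tf.zeros([0]), shape=tf.TensorShape(None), trainable=False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # concatenate the current batch onto what has been collected so far
        self.probs.assign(
            tf.concat([self.probs, tf.reshape(y_pred[:, 1], [-1])], axis=0))
        self.labels.assign(
            tf.concat([self.labels, tf.cast(tf.reshape(y_true, [-1]), tf.float32)], axis=0))

    def result(self):
        # determine the decision threshold from the sorted NOK probabilities
        nok = tf.sort(tf.boolean_mask(self.probs, tf.equal(self.labels, 1.0)))
        idx = tf.cast(tf.round(self.target_escape_rate * tf.cast(tf.size(nok), tf.float32)), tf.int32)
        threshold = tf.gather(nok, tf.maximum(idx - 1, 0))

        # false call rate = fraction of OK samples above the threshold
        ok = tf.boolean_mask(self.probs, tf.equal(self.labels, 0.0))
        return tf.reduce_mean(tf.cast(ok > threshold, tf.float32))

    def reset_state(self):
        self.probs.assign(tf.zeros([0]))
        self.labels.assign(tf.zeros([0]))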

I already tried to instantiate the metric outside of the MirroredStrategy scope, but tf.keras.Model.compile() does not accept that.

Any hints/ideas?

P.S.: Let me know if you need a minimal code example, I am working on it. :)

1 Answer


Solved it myself by implementing the calculation as a callback instead of a metric. I run fit() without validation_data and instead have all validation set metrics calculated in the callback, which avoids a second, redundant prediction pass over the validation set.

In order to inject the resulting metric values back into the training procedure, I used the rather hackish approach from Access variables of caller function in Python.

import inspect

import numpy as np
import tensorflow as tf


class ValidationCallback(tf.keras.callbacks.Callback):
    """helper class to calculate validation set metrics after each epoch"""

    def __init__(self, val_data, escape_rate, **kwargs):
        # call parent constructor
        super(ValidationCallback, self).__init__(**kwargs)

        # save parameters
        self.val_data = val_data
        self.escape_rate = escape_rate

        # declare batch_size - we will get that later
        self.batch_size = 0

    def on_epoch_end(self, epoch, logs=None):
        # initialize empty arrays
        y_pred = np.empty((0,2))
        y_true = np.empty(0)

        # iterate over validation set batches
        for batch in self.val_data:
            # save batch size, if not yet done
            if self.batch_size == 0:
                self.batch_size = batch[1].shape[0]

            # concat all batch labels & predictions
            # need to do predict()[0] due to several model outputs
            y_pred = np.concatenate([y_pred, self.model.predict(batch[0])[0]], axis=0)
            y_true = np.concatenate([y_true, batch[1]], axis=0)

        # calculate classical accuracy for threshold 0.5
        acc = ((y_pred[:, 1] >= 0.5) == y_true).sum() / y_true.shape[0]

        # calculate cross-entropy loss
        cce = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.SUM)
        loss = cce(y_true, y_pred).numpy() / self.batch_size

        # calculate false call rate
        y_pred_nok = np.sort(y_pred[y_true == 1, 1])
        idx = int(np.round(self.escape_rate * y_pred_nok.shape[0]))
        threshold = y_pred_nok[idx]
        false_calls = y_pred[(y_true == 0) & (y_pred[:, 1] >= threshold), 1].shape[0]
        fcr = false_calls / y_true[y_true == 0].shape[0]

        # add metrics to 'logs' dict of our caller (tf.keras.callbacks.CallbackList.on_epoch_end()),
        # so that they become available to following callbacks
        for f in inspect.stack():
            if 'logs' in f[0].f_locals:
                f[0].f_locals['logs'].update({'val_accuracy': acc,
                                              'val_loss': loss,
                                              'val_false_call_rate': fcr})
                return
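
For completeness, a usage sketch (dataset names, epoch count and target escape rate are placeholders). Since the callback injects the metrics into the logs dict seen by the callbacks that run after it, it has to be placed before any callback that monitors those values:

model.fit(
    train_ds,                                        # no validation_data here
    epochs=50,
    callbacks=[
        ValidationCallback(val_ds, escape_rate=0.01),          # must come first
        tf.keras.callbacks.EarlyStopping(monitor='val_false_call_rate',
                                         mode='min', patience=5),
    ])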