The tf.keras.Model I am training has the following primary performance indicators:
- escape rate: (#samples with predicted label 0 AND true label 1) / (#samples with true label 1)
- false call rate: (#samples with predicted label 1 AND true label 0) / (#samples with true label 0)
The targeted escape rate is predefined, which means the decision threshold will have to be set appropriately. To calculate the resulting false call rate, I would like to implement a custom metric somewhere along the lines of the following pseudo code:
# separate predicted probabilities by their true label
all_ok_probabilities = all_probabilities.filter(true_label == 0)
all_nok_probabilities = all_probabilities.filter(true_label == 1)
# sort NOK samples
sorted_nok_probabilities = all_nok_probabilities.sort(ascending)
# determine decision threshold
threshold_idx = round(target_escape_rate * num_nok_samples) - 1
threshold = sorted_nok_probabilities[threshold_idx]
# calculate false call rate
false_calls = count(all_ok_probabilities > threshold)
false_call_rate = false_calls / num_ok_samples
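For reference, the pseudo code above can be written as plain Python without any TensorFlow dependency. This is only a sketch for checking results on a single machine: the function name is made up, the threshold index is computed from the number of NOK samples, and the case where no escape is allowed (which the pseudo code does not cover) is handled by placing the threshold just below the lowest NOK probability, which is my assumption.

```python
import math


def false_call_rate_at_escape_rate(y_true, y_prob, target_escape_rate):
    """Return (threshold, false_call_rate) for a target escape rate.

    Labels: 0 = OK, 1 = NOK. A sample is predicted NOK if its
    probability is strictly greater than the threshold.
    """
    # Separate predicted probabilities by their true label.
    ok_probs = [p for t, p in zip(y_true, y_prob) if t == 0]
    # Sort NOK samples in ascending order.
    nok_probs = sorted(p for t, p in zip(y_true, y_prob) if t == 1)
    if not ok_probs or not nok_probs:
        raise ValueError("need at least one OK and one NOK sample")

    # Number of NOK samples that may be missed. Note that Python's
    # round() rounds exact halves to the nearest even integer.
    allowed_escapes = round(target_escape_rate * len(nok_probs))
    if allowed_escapes == 0:
        # Assumption: no escape allowed, so the threshold sits just
        # below the lowest NOK probability (needs Python 3.9+).
        threshold = math.nextafter(nok_probs[0], -math.inf)
    else:
        threshold = nok_probs[allowed_escapes - 1]

    # Count OK samples that would be flagged as NOK.
    false_calls = sum(1 for p in ok_probs if p > threshold)
    return threshold, false_calls / len(ok_probs)


# Example with made-up data: 4 OK and 5 NOK samples.
labels = [0, 0, 0, 0, 1, 1, 1, 1, 1]
probs = [0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8, 0.9]
print(false_call_rate_at_escape_rate(labels, probs, 0.4))  # (0.4, 0.5)
```

With a target escape rate of 0.4, two of the five NOK samples may escape, so the threshold lands on the second-lowest NOK probability.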
My issue is that, in a MirroredStrategy environment, tf.keras automatically distributes metric calculation across all replicas, each of them getting (batch_size / n_replicas) samples per update, and finally sums the results. My algorithm, however, only works correctly if ALL labels and predictions are combined (the final summing could probably be compensated for by dividing by the number of replicas).
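To illustrate why the per-replica results cannot simply be combined afterwards, here is a small plain-Python simulation. It does not use TensorFlow; the two "replicas" are just hypothetical lists of (true_label, probability) pairs, and the helper is a compact version of the pseudo code that assumes at least one escape is allowed.

```python
def false_call_rate(samples, target_escape_rate):
    """samples: list of (true_label, probability); 0 = OK, 1 = NOK."""
    ok = [p for t, p in samples if t == 0]
    nok = sorted(p for t, p in samples if t == 1)
    # Threshold = highest-scoring NOK sample that is still allowed to
    # escape (assumes the rounded number of allowed escapes is >= 1).
    threshold = nok[round(target_escape_rate * len(nok)) - 1]
    return sum(p > threshold for p in ok) / len(ok)


# Hypothetical global batch of 8 samples, split across two replicas.
replica_0 = [(1, 0.1), (1, 0.2), (0, 0.3), (0, 0.4)]
replica_1 = [(1, 0.6), (1, 0.9), (0, 0.5), (0, 0.7)]
target = 0.5

# Each replica picks its own threshold from its own NOK samples.
per_replica = [false_call_rate(r, target) for r in (replica_0, replica_1)]
# The intended result uses one threshold over all samples.
combined = false_call_rate(replica_0 + replica_1, target)

print(per_replica)                          # [1.0, 0.5]
print(sum(per_replica))                     # 1.5
print(sum(per_replica) / len(per_replica))  # 0.75
print(combined)                             # 1.0
```

Each replica derives a different threshold (0.1 and 0.6 here, versus 0.2 for the combined data), so neither the sum nor the mean of the per-replica values matches the false call rate computed over all samples.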
My idea is to concatenate all y_true and y_pred in my metric's update_state() method into sequences, and run the evaluation in result(). The first step already seems impossible, however: tf.Variable only provides suitable aggregation methods for numeric scalars, not for sequences. tf.VariableAggregation.ONLY_FIRST_REPLICA makes me lose all data from the 2nd to the nth replica, SUM silently locks up the fit() call, and MEAN does not make any sense in my application (and might hang just as well).
I already tried to instantiate the metric outside of the MirroredStrategy scope, but tf.keras.Model.compile() does not accept that.
Any hints/ideas?
P.S.: Let me know if you need a minimal code example, I am working on it. :)