If you use TensorBoard, you already have a historical record, in the form of tfevents files, of all your metric calculations for all your epochs, so a tf.keras.callbacks.Callback is what you want.
I use tf.keras.callbacks.ModelCheckpoint with save_freq='epoch' to save the weights for each epoch, as an h5 file or tf file.
To avoid filling the hard drive with model files, write a new Callback, or extend the ModelCheckpoint class's on_epoch_end implementation:
def on_epoch_end(self, epoch, logs=None):
    super(DropWorseModels, self).on_epoch_end(epoch, logs)
    # Nothing to prune until at least keep_best epochs have run
    if epoch < self._keep_best:
        return

    # All saved model files currently in the model directory
    model_files = frozenset(
        filter(lambda filename: path.splitext(filename)[1] == SAVE_FORMAT_WITH_SEP,
               listdir(self._model_dir)))

    if len(model_files) < self._keep_best:
        return

    # The keep_best highest-ranked (index, value) pairs for the monitored tag
    tf_events_logs = tuple(islice(log_parser(tfevents=path.join(self._log_dir,
                                                                self._split),
                                             tag=self.monitor),
                                  0,
                                  self._keep_best))
    keep_models = frozenset(map(self._filename.format,
                                map(itemgetter(0), tf_events_logs)))

    if len(keep_models) < self._keep_best:
        return

    # Delete every model file that is not among the best
    it_consumes(map(lambda filename: remove(path.join(self._model_dir, filename)),
                    model_files - keep_models))
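The pruning step above can be tried in isolation, without TensorFlow. The following is a minimal sketch, assuming the per-epoch metric values are already available as a plain list (in the answer they come from log_parser) and that higher values are better. The name prune_checkpoints and its parameters are hypothetical, used only for illustration, and the metric values are made up.

```python
import os
import tempfile
from operator import itemgetter


def prune_checkpoints(model_dir, metric_by_epoch, keep_best=2,
                      pattern='model-{:04d}.h5'):
    """Delete all but the `keep_best` highest-scoring checkpoint files.

    Hypothetical helper: metric_by_epoch[i] is the monitored value for
    epoch i, and higher is assumed to be better.
    """
    ranked = sorted(enumerate(metric_by_epoch), key=itemgetter(1), reverse=True)
    keep = frozenset(pattern.format(epoch) for epoch, _ in ranked[:keep_best])
    ext = os.path.splitext(pattern)[1]
    present = frozenset(f for f in os.listdir(model_dir)
                        if os.path.splitext(f)[1] == ext)
    # Mirror the guards in on_epoch_end: do nothing while fewer than
    # keep_best files or ranked epochs exist.
    if len(present) < keep_best or len(keep) < keep_best:
        return frozenset()
    removed = present - keep
    for filename in removed:
        os.remove(os.path.join(model_dir, filename))
    return removed


# Demo with empty placeholder files standing in for saved weights.
with tempfile.TemporaryDirectory() as demo_dir:
    for epoch in range(4):
        open(os.path.join(demo_dir, 'model-{:04d}.h5'.format(epoch)), 'w').close()
    removed_files = prune_checkpoints(demo_dir, [0.61, 0.80, 0.75, 0.79])
    remaining_files = sorted(os.listdir(demo_dir))
```

With these values, epochs 1 and 3 score highest, so only their files remain after the call.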
Appendix (imports and utility function implementations):
from itertools import islice
from operator import itemgetter
from os import path, listdir, remove
from collections import deque

import tensorflow as tf
from tensorflow.core.util import event_pb2


def log_parser(tfevents, tag):
    values = []
    for record in tf.data.TFRecordDataset(tfevents):
        event = event_pb2.Event.FromString(tf.get_static_value(record))
        if event.HasField('summary'):
            value = event.summary.value.pop(0)
            if value.tag == tag:
                values.append(value.simple_value)
    # (index, value) pairs, highest value first
    return tuple(sorted(enumerate(values), key=itemgetter(1), reverse=True))


# Exhaust an iterator (or advance it n steps) without storing its items
it_consumes = lambda it, n=None: deque(it, maxlen=0) if n is None \
    else next(islice(it, n, n), None)

SAVE_FORMAT = 'h5'
SAVE_FORMAT_WITH_SEP = '{}{}'.format(path.extsep, SAVE_FORMAT)
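To make the helpers more concrete, here is a small TensorFlow-free sketch of what the return line of log_parser produces and of what it_consumes does to a lazy iterator. The metric values are made up for illustration, and consume is a hypothetical spelled-out equivalent of the it_consumes lambda.

```python
from collections import deque
from itertools import islice
from operator import itemgetter


def consume(iterator, n=None):
    """Advance `iterator` n steps, or exhaust it entirely when n is None."""
    if n is None:
        deque(iterator, maxlen=0)  # drains the iterator, stores nothing
    else:
        next(islice(iterator, n, n), None)  # skips up to n items


# Made-up metric values, one per logged epoch.
values = [0.61, 0.80, 0.75, 0.79]

# Same expression as the return line of log_parser:
# (index, value) pairs sorted with the highest value first.
ranked = tuple(sorted(enumerate(values), key=itemgetter(1), reverse=True))

# map() is lazy: the appends below only happen once consume() iterates.
seen = []
consume(map(seen.append, map(itemgetter(0), ranked)))

# With n given, only the first n items are skipped.
numbers = iter(range(5))
consume(numbers, 2)
next_number = next(numbers)
```

This is why on_epoch_end wraps its map of remove calls in it_consumes: without something iterating over the map, no file would be deleted.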
For completeness, the rest of the class:
class DropWorseModels(tf.keras.callbacks.Callback):
    """
    Designed around making `save_best_only` work for arbitrary metrics
    and thresholds between metrics
    """

    def __init__(self, model_dir, monitor, log_dir, keep_best=2, split='validation'):
        """
        Args:
            model_dir: directory to save weights. Files will have format
                '{model_dir}/model-{epoch:04d}.h5'.
            split: dataset split to analyse, e.g., one of 'train', 'test', 'validation'
            monitor: quantity to monitor.
            log_dir: the path of the directory where to save the log files to be
                parsed by TensorBoard.
            keep_best: number of models to keep, sorted by monitor value
        """
        super(DropWorseModels, self).__init__()
        self._model_dir = model_dir
        self._split = split
        self._filename = 'model-{:04d}' + SAVE_FORMAT_WITH_SEP
        self._log_dir = log_dir
        self._keep_best = keep_best
        self.monitor = monitor
This has the added advantage of being able to save and delete multiple model files in a single Callback. You can easily extend it with different thresholding support, e.g., to keep all model files with an AUC within a threshold OR with TP, FP, TN, FN within a threshold.
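As one way to picture such a thresholding extension, here is a minimal sketch that selects every epoch whose metric lies within a tolerance of the best value, rather than a fixed top-k, and combines two metrics with OR via a set union. The function name epochs_within_threshold, the tolerance parameter, and all metric values are hypothetical, and higher values are assumed to be better.

```python
def epochs_within_threshold(metric_by_epoch, tolerance):
    """Return the epochs whose metric is within `tolerance` of the best one."""
    if not metric_by_epoch:
        return frozenset()
    best = max(metric_by_epoch)
    return frozenset(epoch for epoch, value in enumerate(metric_by_epoch)
                     if best - value <= tolerance)


# Made-up per-epoch values for two monitored quantities.
auc_by_epoch = [0.61, 0.80, 0.75, 0.79]
tp_by_epoch = [40, 52, 55, 50]

keep_by_auc = epochs_within_threshold(auc_by_epoch, tolerance=0.02)
keep_by_tp = epochs_within_threshold(tp_by_epoch, tolerance=2)

# "AUC within threshold OR TP within threshold" is a set union.
keep_epochs = keep_by_auc | keep_by_tp
```

The resulting epoch set could then be mapped to file names the same way on_epoch_end does with self._filename.format, and every other model file removed.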