The param period mentioned in the accepted answer is no longer available.
Using the save_freq param is an alternative, but risky, as mentioned in the docs: it counts batches rather than epochs, so if the dataset size changes, the save points drift away from epoch boundaries. The docs themselves warn that "if the saving isn't aligned to epochs, the monitored metric may potentially be less reliable".
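For illustration, here is a minimal sketch of that batch-based alternative (the sample count and batch size are made-up values; only the save_freq arithmetic matters):

import math

import tensorflow as tf

# Illustrative numbers -- replace with your actual dataset size and batch size.
num_samples = 1000
batch_size = 32
steps_per_epoch = math.ceil(num_samples / batch_size)

# save_freq is measured in batches, so "every 10 epochs" only holds
# as long as steps_per_epoch stays constant between runs.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "/your_save_location/epoch{epoch:02d}",
    save_freq=10 * steps_per_epoch,
)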
Thus, I use a subclass as a solution:
import tensorflow as tf


class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint variant that saves every `frequency` epochs."""

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
        # Pass save_freq="epoch" so the parent class never saves per batch.
        super(EpochModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only,
                                                   save_weights_only, mode, "epoch", options)
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        # pylint: disable=protected-access
        if self.epochs_since_last_save % self.frequency == 0:
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        # Override the parent's batch-level hook so nothing happens mid-epoch.
        pass
Use it as:
callbacks=[
    EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]
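In a full training run this becomes (model, x_train, y_train, x_val, y_val are placeholders for your own compiled model and data):

model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=[
        EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
    ],
)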
Note that, depending on your TF version, you may have to change the args in the call to the superclass __init__.
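One way to make that call less brittle is to pass keyword arguments instead of relying on positional order (a sketch, assuming the TF 2.x parameter names):

super().__init__(filepath=filepath, monitor=monitor, verbose=verbose,
                 save_best_only=save_best_only, save_weights_only=save_weights_only,
                 mode=mode, save_freq="epoch", options=options, **kwargs)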