
ModelCheckpoint can be used to save the best model based on a specific monitored metric, so it obviously has information about the best metric seen so far stored within the object. If you train on Google Colab, for example, your instance can be killed without warning and you would lose this info after a long training session.

I tried to pickle the ModelCheckpoint object so that I can reuse the same object when I bring my notebook back up, but got:

TypeError: can't pickle _thread.lock objects  

Is there a good way to do this? You can reproduce the error with:

import pickle

import tensorflow as tf

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# Note: the file must be opened in binary mode for pickle.
with open('chkpt_cb.pickle', 'wb') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
kawingkelvin
  • Can you post the current code block you're using? ModelCheckpoint is typically a callback so it's unclear how you're using it from your description. – adamconkey Dec 09 '19 at 22:24
  • @adamconkey I have updated this with code to reproduce. It is fairly simple. I just want to pickle the callback object. Based on the error, it must have something to do with thread issue. – kawingkelvin Dec 10 '19 at 21:20
  • Quick answer I found: pickle chkpt_cb.best, and then reassign it to a new checkpoint. Just tried and it works. – kawingkelvin Dec 10 '19 at 22:57

2 Answers


If the callback object itself is not to be pickled (due to the thread lock issue, and it's not advisable anyway), I can pickle this instead:

best = chkpt_cb.best

This attribute holds the best monitored metric the callback has seen. It is a plain float, so you can pickle it, reload it next time, and then do this:

chkpt_cb.best = best  # if chkpt_cb is a brand-new object you create after Colab killed your session
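
To see why this matters: a freshly constructed ModelCheckpoint starts with best at infinity for a 'min'-type monitor such as val_loss (the exact default may vary slightly across Keras versions), so without restoring the pickled value the new callback would treat any result as an improvement. A minimal sketch:

import tensorflow as tf

fresh_cb = tf.keras.callbacks.ModelCheckpoint('model.h5',
                                              monitor='val_loss',
                                              save_best_only=True)
print(fresh_cb.best)    # inf: any val_loss would count as an improvement
fresh_cb.best = 0.1234  # restoring a previously pickled value fixes that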

This is my own setup:

# All paths should be on Google Drive; I omit that here for simplicity.
import os
import pickle

import tensorflow as tf

chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# Restore the best metric seen so far, if a previous session saved it.
if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    best = pickle.load(f)
    chkpt_cb.best = best

# Persist the best metric after every epoch so a killed session loses nothing.
def save_chkpt_cb():
  with open('chkpt_cb.best.pickle', 'wb') as f:
    pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

history = model.fit_generator(generator=train_data_gen,
                              validation_data=dev_data_gen,
                              epochs=5,
                              callbacks=[chkpt_cb, save_chkpt_cb_callback])

So even when your Colab session gets killed, you can still retrieve the last best metric, inform your new instance about it, and continue training as usual. This especially helps when re-compiling a stateful optimizer causes a regression in the loss/metric and you don't want to save those models for the first few epochs.
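
For completeness, here is a rough sketch of what resuming in a fresh session could look like; the glob pattern and the idea of reloading the newest .h5 file are my own assumptions built on the setup above, not the only way to do it:

import glob
import os
import pickle

import tensorflow as tf

# Reload the most recently written checkpoint, assuming the filepath pattern above.
candidates = sorted(glob.glob('model.*.h5'), key=os.path.getmtime)
if candidates:
  model = tf.keras.models.load_model(candidates[-1])

# Recreate the callback and restore the best metric from the previous session.
chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    chkpt_cb.best = pickle.load(f)

# Then call fit/fit_generator again with callbacks=[chkpt_cb, save_chkpt_cb_callback] as before.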

kawingkelvin

I think you might be misunderstanding the intended usage of the ModelCheckpoint object. It is a callback that gets called periodically during training, at a particular phase. The ModelCheckpoint callback in particular gets called after every epoch (if you keep the default period=1) and saves your model to disk under the filename you specify in the filepath argument. The model is saved in the same way described here. Then, if you want to load that model later, you can do something like

from keras.models import load_model
model = load_model('my_model.h5')

Other answers on SO provide nice guidance and examples for continuing training from a saved model, for example: Loading a trained Keras model and continue training. Importantly, the saved H5 file stores everything about your model that is needed to continue training.
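
As a rough sketch of what that could look like (the file name, data, and epoch count here are placeholders, not taken from the question):

from keras.models import load_model

# load_model restores architecture, weights, and optimizer state from the H5 file,
# so you can keep training without re-compiling.
model = load_model('model.05-0.0023.h5')
model.fit(x_train, y_train,
          epochs=10)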

As suggested in the Keras documentation, you should not use pickle to serialize your model. Simply register the ModelCheckpoint callback with your 'fit' function:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          callbacks=[chkpt_cb])

Your model will be saved in an H5 file named according to that template, with the epoch number and loss value automatically formatted for you. For example, the file saved after the 5th epoch with a validation loss of 0.0023 would be named model.05-0.0023.h5. Since you set save_best_only=True, the model is only saved when the monitored loss improves on the previously saved one, so you don't pollute your directory with a bunch of unneeded model files.
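
If you want to see how that name is produced, the callback essentially applies Python's str.format to the filepath template with the epoch number and the logged metrics, roughly like this:

filepath = 'model.{epoch:02d}-{val_loss:.4f}.h5'
print(filepath.format(epoch=5, val_loss=0.0023))  # -> model.05-0.0023.h5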

adamconkey
  • Yes, I understood this is how it should be used. If you have used Colab and got cut off in the middle of training, you will find that your last best metric is forgotten if you reinstantiate the callback from scratch. So I am trying to find a solution where the callback object can persist on disk. It certainly does persist in memory while your notebook session is live: you can run multiple fit(...) calls and it still tracks the best metric so far. – kawingkelvin Dec 10 '19 at 22:40
  • I found an answer and posted it. The best metric is stored within the callback object for sure. – kawingkelvin Dec 10 '19 at 22:52