My question is quite straightforward, but so far I can't find a definitive answer online.
I have saved the weights of a Keras model trained with the Adam optimizer after a defined number of epochs of training using:
callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])
When I resume training after closing my Jupyter session, can I simply use the following to continue training?
model.load_weights(path)
Since Adam is dependent on the epoch number (for example, in the case of learning rate decay), I would like to know the easiest way to resume training under the same conditions as before.
Following ibarrond's answer, I have written a small custom callback.
import pickle
import tensorflow as tf

# checkpoint_path and optim_state_pkl are file paths defined elsewhere
optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy', metrics=['accuracy'])
weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
    '''Custom callback to save optimiser state'''
    def on_epoch_end(self, epoch, logs=None):
        # get_config() returns the optimiser's hyperparameters as a dict
        optim_state = optim.get_config()
        with open(optim_state_pkl, 'wb') as f_out:
            pickle.dump(optim_state, f_out)

model.fit(X, y, callbacks=[weight_callback, optim_callback()])
When I resume training:
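model.load_weights(checkpoint_path)
with open(optim_state_pkl, 'rb') as f_in:
    optim_state = pickle.load(f_in)
# from_config() returns a new optimiser, so keep the result and recompile with it
optim = tf.keras.optimizers.Adam.from_config(optim_state)
model.compile(optimizer=optim, loss='categorical_crossentropy', metrics=['accuracy'])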
I would just like to check if this is correct. Many thanks again!!
Addendum: On further reading of the default Keras implementation of Adam and the original Adam paper, I believe that the default Adam does not depend on the epoch number, only on the iteration number. The callback above is therefore unnecessary in that case. However, the code may still be useful for anyone who wishes to keep track of other optimisers.
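
To illustrate the point about the iteration number, here is a minimal pure-Python sketch of the Adam update as described in the original paper, for a single scalar weight. It is not the Keras implementation: the helper names (`adam_step`, `train`), the gradient values and the hyperparameter defaults are all made up for illustration. It suggests that the update depends on the iteration counter `t` and on the running moments `m` and `v`, so resuming with the full state reproduces an uninterrupted run, while resuming with only the weights does not.

```python
import math

def adam_step(w, grad, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam update for a scalar weight; `state` is the tuple (m, v, t)."""
    m, v, t = state
    t += 1  # counts parameter updates (batches), not epochs
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction uses t
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, (m, v, t)

def train(w, state, grads):
    """Apply one Adam step per gradient and return the final weight and state."""
    for g in grads:
        w, state = adam_step(w, g, state)
    return w, state

grads = [0.5, -0.2, 0.1, 0.4, -0.3, 0.2]  # hypothetical per-batch gradients
fresh = (0.0, 0.0, 0)                     # state of a newly created optimiser

# Uninterrupted run over all six updates
w_full, _ = train(1.0, fresh, grads)

# Interrupted after three updates, then resumed with the saved (m, v, t)
w_half, state_half = train(1.0, fresh, grads[:3])
w_resumed, _ = train(w_half, state_half, grads[3:])

# Interrupted after three updates, then resumed with the weight only
w_reset, _ = train(w_half, fresh, grads[3:])

print("uninterrupted:       ", w_full)
print("resumed with state:  ", w_resumed)
print("resumed weights only:", w_reset)
```

In this toy setting the run resumed with `(m, v, t)` matches the uninterrupted one, and the weights-only run drifts slightly. How much that drift matters for a real model is a separate question that this sketch does not answer.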