I am currently following the TensorFlow documentation to save and restore the trained NN model weights:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

I am also familiar with checkpoints during training, but my question is:

Can we keep the model weights in memory (for example, in a local or global variable) while we are training the model, instead of saving them to a file?

I am doing something like a grid search: I have a loop in which, in each iteration, I partially train my model on one portion of the dataset, save the learned weights, and then continue training on another portion of the dataset.

Sample pseudo-code of my workflow:

for i in range(1,10):
    - use dataset A1 for training
    - train model on dataset A1
    - test on the testing dataset X
    - save model weights
    - restore model weights
    - now use dataset A2
    - run model on trained weights to see initial accuracy
    - retrain the model on dataset A2 and keep previously saved weights
    - save model weights
end
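The workflow above can be sketched with the in-memory counterparts of save_weights/load_weights: model.get_weights() returns the weights as a list of NumPy arrays, and model.set_weights() loads such a list back, so no file is involved. This is a minimal sketch under stated assumptions: create_model() is a hypothetical tiny regression model, the random chunks stand in for datasets A1, A2, ... and test set X, and test loss is used in place of accuracy. It rebuilds the model for each chunk so that the restore step is real, mirroring the question's file-based example.

```python
import numpy as np
from tensorflow import keras

def create_model():
    # Hypothetical stand-in for the question's create_model()
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model

rng = np.random.default_rng(0)
# Hypothetical data: training chunks standing in for A1, A2, ... and a fixed test set X
chunks = [(rng.normal(size=(32, 4)), rng.normal(size=(32, 1))) for _ in range(3)]
x_test, y_test = rng.normal(size=(16, 4)), rng.normal(size=(16, 1))

saved_weights = None  # in-memory "checkpoint": a list of NumPy arrays
results = []          # (test loss before, test loss after) training on each chunk

for x_train, y_train in chunks:
    model = create_model()  # fresh instance, like the file-based example in the question
    if saved_weights is not None:
        model.set_weights(saved_weights)  # restore the weights learned on earlier chunks
    loss_before = model.evaluate(x_test, y_test, verbose=0)  # test loss with restored weights
    model.fit(x_train, y_train, epochs=2, verbose=0)         # continue training on this chunk
    loss_after = model.evaluate(x_test, y_test, verbose=0)
    saved_weights = model.get_weights()  # keep a copy in memory instead of writing a file
    results.append((loss_before, loss_after))
```

get_weights() covers only the layer weights, not the optimizer state, so each freshly built model starts with a new optimizer. If you keep reusing the same model object instead, you can skip create_model()/set_weights() and the optimizer state carries over between fit() calls.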

I have already looked at other posts like this one, but they don't answer my question.

Bilgin

1 Answer

Yes, you can. One way to do it is with a custom callback. The callback stores the weights in a class variable called best_weights, which is updated at the end of each epoch whenever the validation loss is the lowest seen so far. Below is the code.

import numpy as np
from tensorflow import keras

class LRA(keras.callbacks.Callback):
    best_weights = None  # class variable, so the weights can be used after training is completed

    def __init__(self, model, verbose):
        super(LRA, self).__init__()
        # Keras attaches the model to self.model when the callback is passed to fit()
        self.verbose = verbose
        self.lowest_vloss = np.inf  # set lowest validation loss to infinity
        LRA.best_weights = model.get_weights()  # start from the initial weights

    def on_epoch_end(self, epoch, logs=None):  # method runs at the end of each epoch
        v_loss = (logs or {}).get('val_loss')  # get the validation loss for this epoch
        if v_loss is not None and v_loss < self.lowest_vloss:  # check if the validation loss improved
            if self.verbose == 1:
                msg = f' validation loss improved from {self.lowest_vloss:8.5f} to {v_loss:8.5f}, saving best weights'
                print(msg)
            self.lowest_vloss = v_loss  # replace lowest validation loss with new validation loss
            LRA.best_weights = self.model.get_weights()  # validation loss improved, so save the weights

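In model.fit, include callbacks=[LRA(model, verbose=1)] and pass validation data (for example validation_data=... or validation_split=0.2); without it, val_loss is not available to the callback.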

The callback makes the class variable LRA.best_weights available. It holds the model weights from the epoch that achieved the lowest validation loss, so you can restore them with, for example, model.set_weights(LRA.best_weights). The callback's parameter model is your model. The parameter verbose is an integer: if it is 1, a message is printed at the end of every epoch in which the validation loss improved and the best weights were saved; for any other value, nothing is printed.
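A usage sketch of the approach described above, under stated assumptions: BestWeights is a hypothetical, trimmed-down variant of the callback (no verbose option) that stores the best weights as an instance attribute rather than a class variable, and the model and random data are placeholders. It trains with validation data, then rolls the model back to the best epoch entirely in memory.

```python
import numpy as np
from tensorflow import keras

class BestWeights(keras.callbacks.Callback):
    """Hypothetical helper: keeps the weights with the lowest val_loss in memory."""

    def __init__(self):
        super().__init__()
        self.lowest_vloss = np.inf
        self.best_weights = None  # instance attribute instead of a class variable

    def on_epoch_end(self, epoch, logs=None):
        v_loss = (logs or {}).get("val_loss")
        if v_loss is not None and v_loss < self.lowest_vloss:
            self.lowest_vloss = v_loss
            self.best_weights = self.model.get_weights()  # NumPy copies of the current weights

rng = np.random.default_rng(0)
# Placeholder data: a training set and a separate validation set
x, y = rng.normal(size=(64, 4)), rng.normal(size=(64, 1))
x_val, y_val = rng.normal(size=(16, 4)), rng.normal(size=(16, 1))

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

best = BestWeights()
model.fit(x, y, validation_data=(x_val, y_val), epochs=5, verbose=0, callbacks=[best])
model.set_weights(best.best_weights)  # roll back to the epoch with the lowest val_loss
```

Keeping best_weights on the instance means that two callbacks alive at the same time (for example, one per data chunk in the loop from the question) do not overwrite each other, whereas LRA.best_weights is shared by every LRA instance.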

Gerry P