
I am trying to implement the following function to save the model_state checkpoints:

def train_epoch(self):
    for epoch in tqdm.trange(self.epoch, self.max_epoch, desc='Train Epoch', ncols=100):
        self.epoch = epoch  # keeps the Trainer's epoch counter in sync
        checkpoint = {}  # FIXME: is (re)creating the checkpoint dict here correct?
        # model_save_criteria = self.model_save_criteria
        self.train()
        if epoch % 1 == 0:  # validate every epoch
            self.validate(checkpoint)
        checkpoint_latest = {
            'epoch': self.epoch,
            'arch': self.model.__class__.__name__,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'model_save_criteria': self.model_save_criteria,
        }
        checkpoint['checkpoint_latest'] = checkpoint_latest
        torch.save(checkpoint, self.model_pth)

Previously I did the same by just running a for loop:

train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        # ... training step ...
        # ... validation ...
        train_states_latest = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'model_save_criteria': chosen_criteria,
        }
        train_states['train_states_latest'] = train_states_latest
        torch.save(train_states, FILEPATH_MODEL_SAVE)

Is there a way to initialize `checkpoint = {}` once and then update it every loop? Or is re-creating `checkpoint = {}` in every epoch fine, since the model itself holds the `state_dict()` and I am just overwriting the checkpoint file each time?
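
For example, something like this is what I have in mind (a rough sketch; `self.checkpoint` is a name I am making up here), where the dict is created once in `__init__` and then updated in place every epoch:

def __init__(self, model, optim, model_pth, max_epoch):
    self.model = model
    self.optim = optim
    self.model_pth = model_pth
    self.max_epoch = max_epoch
    self.epoch = 0
    self.checkpoint = {}  # created once, updated every epoch

def train_epoch(self):
    for epoch in tqdm.trange(self.epoch, self.max_epoch, desc='Train Epoch', ncols=100):
        self.epoch = epoch
        self.train()
        self.validate(self.checkpoint)
        # update the same dict instead of re-creating it
        self.checkpoint['checkpoint_latest'] = {
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }
        torch.save(self.checkpoint, self.model_pth)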

banikr

1 Answer


You can avoid overwriting the checkpoint by simply changing the FILEPATH_MODEL_SAVE path so that it contains the epoch or iteration number. For example (taking your original code):

train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        # ... training step ...
        # ... validation ...
        train_states_latest = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'model_save_criteria': chosen_criteria,
        }
        train_states['train_states_latest'] = train_states_latest

        # This is the code you can add: a path that encodes the epoch and batch number
        FILEPATH_MODEL_SAVE = "Epoch{}batch{}model_weights.pth".format(epoch, bIdx)
        torch.save(train_states, FILEPATH_MODEL_SAVE)


With this new line above torch.save, you avoid overwriting the checkpoint, because every save goes to its own file.
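
If you instead want to keep only the latest and best states in a single file, as your original loop does, a minimal sketch could look like the following (the `train_states_best` key and the lower-is-better comparison on `chosen_criteria` are my own assumptions):

best_criteria = float('inf')  # assuming a lower model_save_criteria is better
train_states = {}
for epoch in range(max_epochs):
    # ... training and validation as before, producing chosen_criteria ...
    train_states_latest = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_save_criteria': chosen_criteria,
    }
    train_states['train_states_latest'] = train_states_latest
    if chosen_criteria < best_criteria:
        best_criteria = chosen_criteria
        train_states['train_states_best'] = train_states_latest  # best state kept alongside the latest
    torch.save(train_states, FILEPATH_MODEL_SAVE)  # one file always holds both states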


SarthakJain
  • But that will save multiple models; to be exact, if I train for 200 epochs with 60 batches per epoch, there will be 200x60 model files. The for loop code I shared already works, saving just two states (best and latest) to a single path. I just wanted to implement that in a function, as I showed. – banikr Jul 24 '21 at 20:51
  • I'm not quite understanding what exactly you want the new code to do. I understand that you already got it to save the latest and best weights. What new functionality are you aiming for? – SarthakJain Jul 25 '21 at 03:38
  • @SarthakJain sorry for not clarifying. The `for` loop I posted already works perfectly; I want to implement it as a function. I wrote/started some of the code, and whether the `checkpoint = {}` is in the right place is confusing me. – banikr Jul 25 '21 at 04:48