
I currently have my trainer set up as:

training_args = TrainingArguments(
    output_dir=f"./results_{model_checkpoint}",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    save_total_limit=1,
    resume_from_checkpoint=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_qa["train"],
    eval_dataset=tokenized_qa["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics
)

After training, in my output_dir I have several files that the trainer saved:

['README.md',
 'tokenizer.json',
 'training_args.bin',
 '.git',
 '.gitignore',
 'vocab.txt',
 'config.json',
 'checkpoint-5000',
 'pytorch_model.bin',
 'tokenizer_config.json',
 'special_tokens_map.json',
 '.gitattributes']

From the documentation it seems that resume_from_checkpoint will continue training the model from the last checkpoint:

resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
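If I'm reading that correctly, either of these calls should pick up from checkpoint-5000 (just a sketch, using my output_dir from above):

trainer.train(resume_from_checkpoint=True)  # load the newest checkpoint in args.output_dir

# or point at a specific checkpoint directory explicitly:
trainer.train(resume_from_checkpoint=f"./results_{model_checkpoint}/checkpoint-5000")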

But when I call trainer.train() it seems to delete the last checkpoint and start a new one:

Saving model checkpoint to ./results_distilbert-base-uncased/checkpoint-500
...
Deleting older checkpoint [results_distilbert-base-uncased/checkpoint-5000] due to args.save_total_limit

Does it really continue training from the last checkpoint (i.e., step 5000) and just restart the checkpoint numbering at 0 (saving the first one after 500 steps as "checkpoint-500"), or does it simply not continue training? I haven't found a way to test this, and the documentation is not clear on that.

Penguin

3 Answers


Looking at the code, it first loads the checkpoint state, updates the count of how many epochs have already been run, and continues training from there up to the total num_train_epochs you set (there is no reset to 0).

To see it continue training, increase your num_train_epochs before calling trainer.train() on your checkpoint.
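A minimal sketch of what I mean, assuming the training_args and trainer objects from the question are still in scope:

training_args.num_train_epochs = 4  # raised from 2, so there are epochs left to run
trainer.train(resume_from_checkpoint=True)  # restores model/optimizer/scheduler state and continues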

Suraj813
  • I tried increasing the epochs to 220 after training it so far on 110. It still started the counter from 0 (it finished the 110 epochs at step 16500, but started from 0 for the 220-epoch call) – Penguin Jun 27 '22 at 21:19

You should also pass the resume_from_checkpoint parameter to trainer.train() with the path to the checkpoint:

trainer.train(resume_from_checkpoint="<path-where-checkpoints-were-stored>/checkpoint-0000")

where 0000 is an example of a checkpoint number.

Don't forget to keep your drive mounted during this whole process.
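If you'd rather not hard-code the checkpoint number, transformers also ships a helper that finds the newest checkpoint-* folder for you; a sketch, reusing the output_dir from the question:

from transformers.trainer_utils import get_last_checkpoint

# Returns the path of the highest-numbered checkpoint-* folder, or None if there is none.
last_checkpoint = get_last_checkpoint(f"./results_{model_checkpoint}")
if last_checkpoint is not None:
    trainer.train(resume_from_checkpoint=last_checkpoint)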

  • link to doc https://huggingface.co/docs/transformers/main_classes/trainer#checkpoints – Halyna Symonets Jul 18 '22 at 13:54
  • The "boolean" approach `resume_from_checkpoint=True` mentions "which will resume training from the latest checkpoint". Which unless I'm missing something implies that it will automatically use the last checkpoint, so I don't need to specify the last one with a path. Nevertheless, it doesn't really answer my question about whether it works (e.g., the counts problem) – Penguin Jul 18 '22 at 16:48
  • Yes, you're right, they promise it works the way you said. I just wanted to make sure the trainer uses the right checkpoint, so I specified the path. I don't know how it works with only the boolean param; try asking on the Hugging Face forum – Halyna Symonets Jul 19 '22 at 19:53

Yes, it works! When you call trainer.train() with no arguments, you're implicitly telling it to start from scratch, and with save_total_limit=1 it deletes the old checkpoint as soon as it saves a new one. The resume_from_checkpoint you set in TrainingArguments is not used by trainer.train() automatically; you should call trainer.train(resume_from_checkpoint=True) or pass a string pointing to the checkpoint path.
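One way to convince yourself that it resumed, sketched below: every checkpoint folder the Trainer writes contains a trainer_state.json recording the step it was saved at, so you can compare that stored step with where the counter starts after resuming.

import json

# Read the step the checkpoint was saved at (5000 for the question's checkpoint).
with open(f"./results_{model_checkpoint}/checkpoint-5000/trainer_state.json") as f:
    state = json.load(f)
print(state["global_step"])  # a real resume continues counting from this step

After a proper resume, the next checkpoint should be numbered from the restored step (e.g. checkpoint-5500 with the default save_steps=500), so a fresh checkpoint-500 like the one in the question's log means training restarted from zero.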

Wayne DSouza
  • Will this also work if the previously saved checkpoints were created by a script that called trainer.train(), and I now change the script to trainer.train(resume_from_checkpoint=True) and rerun it? – Danny Aug 21 '23 at 02:46