I'm running a PyTorch model that someone else has already written. In that code, the authors save checkpoint (ckpt) files every epoch, but they didn't provide any option to resume training from one of these checkpoints. Here is the authors' original code:
```python
trainer = pl.Trainer.from_argparse_args(
    args,
    default_root_dir=args.logdir,
    gpus=args.gpus,
    accelerator='ddp',
    sync_batchnorm=True,
    plugins=DDPPlugin(find_unused_parameters=False),
    profiler='simple',
    benchmark=True,
    log_every_n_steps=1,
    flush_logs_every_n_steps=5,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=args.val_every,
    max_epochs=args.epochs,
    logger=logger,
)
```
So I changed the code above so that training can start from a given checkpoint, selected with a command-line argument. Here is what I did:
```python
if args.loadfromcheckpoint > 0:
    trainer = pl.Trainer(
        resume_from_checkpoint=args.logdir + "/epoch={checkpoint}-last.ckpt".format(
            checkpoint=args.loadfromcheckpoint),
        default_root_dir=args.logdir,
        gpus=args.gpus,
        accelerator='ddp',
        sync_batchnorm=True,
        plugins=DDPPlugin(find_unused_parameters=False),
        profiler='simple',
        benchmark=True,
        log_every_n_steps=1,
        flush_logs_every_n_steps=5,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=args.val_every,
        max_epochs=args.epochs,
        logger=logger,
    )
    trainer.fit(TCP_model, dataloader_train, dataloader_val)
else:
    # uses the trainer created by the authors' original code above
    trainer.fit(TCP_model, dataloader_train, dataloader_val)
```
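The `if`/`else` duplication could probably be avoided: `resume_from_checkpoint` defaults to `None`, which starts training from scratch, so a single `Trainer` could receive either a checkpoint path or `None`. A minimal sketch of a hypothetical helper `resume_path` (not part of Lightning or the authors' code) that builds the path with the same `epoch=<N>-last.ckpt` naming and fails early if the file is missing:

```python
import os


def resume_path(logdir, epoch):
    """Hypothetical helper: return the checkpoint path for `epoch`, or None.

    Assumes the authors' naming scheme "epoch=<N>-last.ckpt" inside logdir.
    Mirrors the original `args.loadfromcheckpoint > 0` condition.
    """
    if epoch is None or epoch <= 0:
        return None  # None means "train from scratch"
    path = os.path.join(logdir, "epoch={}-last.ckpt".format(epoch))
    if not os.path.isfile(path):
        # fail before spawning DDP processes rather than deep inside fit()
        raise FileNotFoundError("no checkpoint at " + path)
    return path
```

It would then be passed once, as `resume_from_checkpoint=resume_path(args.logdir, args.loadfromcheckpoint)`, and `trainer.fit(...)` called unconditionally.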
The code above works fine. Since I'm quite new to PyTorch and PyTorch Lightning, I have the following questions:
- Does the Lightning API restore only the model's `state_dict`, or does it also restore everything else, such as `optimizer_states` and `lr_schedulers`?
- If Lightning doesn't restore all of those, how can I load those states manually?
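If those states do have to be restored by hand, a minimal sketch of doing it with plain PyTorch, assuming the checkpoint is a Lightning-style dict with `state_dict`, `optimizer_states`, `lr_schedulers`, `epoch` and `global_step` keys (the exact keys may differ between Lightning versions, so printing `torch.load(path).keys()` is worth doing first). The `nn.Linear` model, SGD optimizer and `StepLR` scheduler are hypothetical stand-ins; in the real script they would be `TCP_model` and whatever its `configure_optimizers` returns. The demo writes a small checkpoint with that layout and restores it into fresh objects:

```python
import os
import tempfile

import torch
from torch import nn


def build_training_objects():
    """Hypothetical stand-ins for the real model, optimizer and scheduler."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler


def restore_from_checkpoint(path, model, optimizer, scheduler):
    # map_location="cpu" avoids needing the device the checkpoint was saved on
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    # assumed layout: one list entry per optimizer / scheduler, in order
    optimizer.load_state_dict(ckpt["optimizer_states"][0])
    scheduler.load_state_dict(ckpt["lr_schedulers"][0])
    return ckpt["epoch"], ckpt["global_step"]


# Demo: train two steps, save a dict with the assumed layout, then restore it
model, optimizer, scheduler = build_training_objects()
for _ in range(2):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # lr: 0.1 -> 0.05 -> 0.025

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.ckpt")
    torch.save({
        "epoch": 1,
        "global_step": 2,
        "state_dict": model.state_dict(),
        "optimizer_states": [optimizer.state_dict()],
        "lr_schedulers": [scheduler.state_dict()],
    }, path)

    new_model, new_opt, new_sched = build_training_objects()
    epoch, step = restore_from_checkpoint(path, new_model, new_opt, new_sched)
```

Restoring the optimizer matters for momentum-based optimizers: the momentum buffers live in the optimizer state, not in the model's `state_dict`, so loading only the weights would silently reset them.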