Why is it recommended to save the state dicts and load them, instead of saving the whole objects (e.g. with dill) and then just getting usable objects back immediately?
I think I've done the latter without many issues, and it saves the user code.
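To be concrete, here is a minimal sketch of the "save the whole object" approach I mean (the file name and toy model are made up for illustration; the point is that loading gives back usable objects directly):
import dill
import torch
from torch import nn, optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# save the live objects directly, letting dill handle the pickling
torch.save({'model': model, 'optimizer': optimizer}, 'ckpt.pt', pickle_module=dill)

# later: get usable objects back immediately, no re-construction needed
ckpt = torch.load('ckpt.pt', pickle_module=dill)
model, optimizer = ckpt['model'], ckpt['optimizer']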
But instead we are recommended to do something like:
def _load_model_and_optimizer_from_checkpoint(args: Namespace, training: bool = True) -> Namespace:
    """
    Based on: https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    import torch
    from torch import optim
    import torch.nn as nn

    # - re-create the model and optimizer objects first (architecture/hyperparams here are placeholders)
    # model = Net()
    args.model = nn.Linear(in_features=10, out_features=1)
    # optimizer = optim.SGD(args.model.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(args.model.parameters(), lr=0.001)
    # scheduler...

    # - load the checkpoint and restore the state dicts into the fresh objects
    checkpoint = torch.load(args.PATH)
    args.model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    args.epoch_num = checkpoint['epoch_num']
    args.loss = checkpoint['loss']

    args.model.train() if training else args.model.eval()
    return args
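The matching save side from that tutorial would look roughly like this (a sketch; it assumes the optimizer is also kept on args, and the keys mirror the load above):
def _save_checkpoint(args: Namespace):
    """
    Sketch of the corresponding save, mirroring the keys the load above expects
    (assumes the optimizer is stored on args as args.optimizer).
    """
    import torch
    torch.save({'epoch_num': args.epoch_num,
                'model_state_dict': args.model.state_dict(),
                'optimizer_state_dict': args.optimizer.state_dict(),
                'loss': args.loss,
                },
               args.PATH)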
For example, I've saved:
def save_for_meta_learning(args: Namespace, ckpt_filename: str = 'ckpt.pt'):
    if is_lead_worker(args.rank):
        import dill
        args.logger.save_current_plots_and_stats()

        # - ckpt
        assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
        f: nn.Module = get_model_from_ddp(args.base_model)
        # pickle vs torch.save https://discuss.pytorch.org/t/advantages-disadvantages-of-using-pickle-module-to-save-models-vs-torch-save/79016
        args_pickable: Namespace = uutils.make_args_pickable(args)

        torch.save({'training_mode': args.training_mode,  # assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
                    'it': args.it,
                    'epoch_num': args.epoch_num,
                    'args': args_pickable,  # some versions of this checkpoint might not have args!
                    'meta_learner': args.meta_learner,
                    'meta_learner_str': str(args.meta_learner),  # added later, to make it easier to check what meta-learner was used
                    'f': f,
                    'f_state_dict': f.state_dict(),  # added later, to make it easier to restore the model from a state dict
                    'f_str': str(f),  # added later, to make it easier to check what model was used
                    # 'f_modules': f._modules,
                    # 'f_modules_str': str(f._modules),
                    'outer_opt': args.outer_opt,  # added later, to make it easier to check what optimizer was used
                    'outer_opt_state_dict': args.outer_opt.state_dict(),  # added later, to make it easier to restore the optimizer from a state dict
                    'outer_opt_str': str(args.outer_opt)  # added later, to make it easier to check what optimizer was used
                    },
                   pickle_module=dill,
                   f=args.log_root / ckpt_filename)
then loaded:
def get_model_opt_meta_learner_to_resume_checkpoint_resnets_rfs(args: Namespace,
                                                                path2ckpt: str,
                                                                filename: str,
                                                                device: Optional[torch.device] = None
                                                                ) -> tuple[nn.Module, optim.Optimizer, MetaLearner]:
    """
    Get the model, optimizer, meta_learner to resume training from checkpoint.

    Examples:
        - see: _resume_from_checkpoint_meta_learning_for_resnets_rfs_test
    """
    import uutils
    path2ckpt: Path = Path(path2ckpt).expanduser() if isinstance(path2ckpt, str) else path2ckpt.expanduser()
    ckpt: dict = torch.load(path2ckpt / filename, map_location=torch.device('cpu'))

    # args_ckpt: Namespace = ckpt['args']
    training_mode = ckpt.get('training_mode')
    if training_mode is not None:
        assert uutils.xor(training_mode == 'epochs', training_mode == 'iterations')
        if training_mode == 'epochs':
            args.epoch_num = ckpt['epoch_num']
        else:
            args.it = ckpt['it']

    # - get meta-learner
    meta_learner: MetaLearner = ckpt['meta_learner']
    # - get model
    model: nn.Module = meta_learner.base_model
    # - get outer-opt
    outer_opt_str = ckpt.get('outer_opt_str')
    if outer_opt_str is not None:
        # use the string to create the optimizer, load the state dict, etc.
        outer_opt: optim.Optimizer = get_optimizer(outer_opt_str)
        outer_opt_state_dict: dict = ckpt['outer_opt_state_dict']
        outer_opt.load_state_dict(outer_opt_state_dict)
    else:
        # this is not ideal, but since Adam keeps an exponential moving average for its adaptive learning rate,
        # hopefully this doesn't mess up the checkpoint too much
        outer_opt: optim.Optimizer = optim.Adam(model.parameters(), lr=args.outer_lr)

    # - device setup
    if device is not None:
        # if torch.cuda.is_available():
        #     meta_learner.base_model = meta_learner.base_model.cuda()
        meta_learner.base_model = meta_learner.base_model.to(device)
    return model, outer_opt, meta_learner
without issues.
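For comparison, restoring from that same checkpoint using only the state dicts would look roughly like this (a sketch: the nn.Linear architecture and the function name are placeholders; the real model has to be re-created by your own code before loading the state dict):
def _load_from_state_dicts_only(args: Namespace, path2ckpt: Path, filename: str) -> tuple[nn.Module, optim.Optimizer]:
    """
    Sketch: restore using only the state dicts saved in the checkpoint above.
    The architecture below is a placeholder; re-build the real model in code first.
    """
    import torch
    from torch import nn, optim
    ckpt: dict = torch.load(path2ckpt / filename, map_location=torch.device('cpu'))
    model: nn.Module = nn.Linear(in_features=10, out_features=1)  # placeholder: re-create the real architecture here
    model.load_state_dict(ckpt['f_state_dict'])
    outer_opt: optim.Optimizer = optim.Adam(model.parameters(), lr=args.outer_lr)
    outer_opt.load_state_dict(ckpt['outer_opt_state_dict'])
    return model, outer_opt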
Related:
- Save and load model optimizer state
- pytorch save and load model
- Save and load a Pytorch model
- save and load unserialized pytorch pretrained model
- https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
- Why is it not recommended to save the optimizer, model etc as pickable/dillable objs in PyTorch but instead get the state dicts and load them? https://discuss.pytorch.org/t/why-is-it-not-recommended-to-save-the-optimizer-model-etc-as-pickable-dillable-objs-in-pytorch-but-instead-get-the-state-dicts-and-load-them/137933