2

Why is it recommended to save the state dicts and load them instead of saving stuff with dill for example and then just getting the usable objects immediately?

I think I've done that without many issues and it saves the user's code.

But instead we are recommended to do something like:

def _load_model_and_optimizer_from_checkpoint(args: Namespace, training: bool = True) -> Namespace:
    """
    based from: https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    import torch
    from torch import optim
    import torch.nn as nn
    # model = Net()
    args.model = nn.Linear()
    # optimizer = optim.SGD(args.model.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(args.model.parameters(), lr=0.001)

    # scheduler...

    checkpoint = torch.load(args.PATH)
    args.model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    args.epoch_num = checkpoint['epoch_num']
    args.loss = checkpoint['loss']

    args.model.train() if training else args.model.eval()

For example I've saved:

def save_for_meta_learning(args: Namespace, ckpt_filename: str = 'ckpt.pt'):
    """Write a meta-learning checkpoint from the lead worker only.

    The payload stores the live objects (meta-learner, base model, outer optimizer)
    alongside their state dicts and string representations, so the checkpoint stays
    inspectable even if the defining code later moves or changes.
    """
    # only the lead worker checkpoints; everyone else is a no-op
    if not is_lead_worker(args.rank):
        return
    import dill

    args.logger.save_current_plots_and_stats()

    # - ckpt
    # exactly one of the two training modes must be active
    assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
    base_model: nn.Module = get_model_from_ddp(args.base_model)
    # pickle vs torch.save https://discuss.pytorch.org/t/advantages-disadvantages-of-using-pickle-module-to-save-models-vs-torch-save/79016
    args_pickable: Namespace = uutils.make_args_pickable(args)

    payload = {
        'training_mode': args.training_mode,
        'it': args.it,
        'epoch_num': args.epoch_num,

        'args': args_pickable,  # some versions of this might not have args!

        'meta_learner': args.meta_learner,
        'meta_learner_str': str(args.meta_learner),  # string form for quick inspection

        'f': base_model,
        'f_state_dict': base_model.state_dict(),
        'f_str': str(base_model),
        # 'f_modules': f._modules,
        # 'f_modules_str': str(f._modules),

        'outer_opt': args.outer_opt,
        'outer_opt_state_dict': args.outer_opt.state_dict(),
        'outer_opt_str': str(args.outer_opt),  # lets you check which optimizer was used
    }
    torch.save(payload,
               pickle_module=dill,
               f=args.log_root / ckpt_filename)

then loaded:

def get_model_opt_meta_learner_to_resume_checkpoint_resnets_rfs(args: Namespace,
                                                                path2ckpt: str,
                                                                filename: str,
                                                                device: Optional[torch.device] = None
                                                                ) -> tuple[nn.Module, optim.Optimizer, MetaLearner]:
    """
    Restore the model, optimizer and meta-learner from a checkpoint so training can resume.

    Examples:
        - see: _resume_from_checkpoint_meta_learning_for_resnets_rfs_test
    """
    import uutils
    ckpt_dir: Path = Path(path2ckpt).expanduser() if isinstance(path2ckpt, str) else path2ckpt.expanduser()
    # always materialize on CPU first; the caller moves to `device` at the end
    ckpt: dict = torch.load(ckpt_dir / filename, map_location=torch.device('cpu'))
    # args_ckpt: Namespace = ckpt['args']
    # - restore the training-progress counter, if the checkpoint recorded one
    training_mode = ckpt.get('training_mode')
    if training_mode is not None:
        assert uutils.xor(training_mode == 'epochs', training_mode == 'iterations')
        if training_mode == 'epochs':
            args.epoch_num = ckpt['epoch_num']
        else:
            args.it = ckpt['it']
    # - the meta-learner (and its base model) were pickled whole
    meta_learner: MetaLearner = ckpt['meta_learner']
    model: nn.Module = meta_learner.base_model
    # - rebuild the outer optimizer from its saved string + state dict when available
    opt_str = ckpt.get('outer_opt_str')
    if opt_str is None:
        # Fallback: a fresh Adam. Not ideal, since Adam keeps exponential moving
        # averages for its adaptive learning rate; hopefully losing that state
        # doesn't hurt the resumed run too much.
        outer_opt: optim.Optimizer = optim.Adam(model.parameters(), lr=args.outer_lr)
    else:
        outer_opt: optim.Optimizer = get_optimizer(opt_str)
        outer_opt.load_state_dict(ckpt['outer_opt_state_dict'])
    # - device setup
    if device is not None:
        # if torch.cuda.is_available():
        #     meta_learner.base_model = meta_learner.base_model.cuda()
        meta_learner.base_model = meta_learner.base_model.to(device)
    return model, outer_opt, meta_learner

without issues.


Related:

Charlie Parker
  • 5,884
  • 57
  • 198
  • 323
  • cross posted: https://discuss.pytorch.org/t/why-is-it-not-recommended-to-save-the-optimizer-model-etc-as-pickable-dillable-objs-in-pytorch-but-instead-get-the-state-dicts-and-load-them/137933 – Charlie Parker Nov 26 '21 at 20:49
  • I think the main reason is this: https://stackoverflow.com/questions/70341854/how-to-open-a-pickled-file-with-dill-where-the-objects-are-not-findable-anymore basically, if you save objects and refactor the code it might be that dill has issues working since it doesn't know where the code is...perhaps making your checkpoint unusable. – Charlie Parker Dec 14 '21 at 00:11

0 Answers