
I usually alternate between training and testing in my code. This means I call mdl.train() when returning to the training loop and mdl.eval() when entering the testing loop. However, I've noticed that this may cause problems with how state that depends on the train/eval flag (e.g. batch norm running statistics, perhaps others) gets saved. This raises the question of how, and whether, these flags should be called at all. If I run evaluation before checkpointing, I believe this will delete my running averages, as mentioned here: How does one use the mean and std from training in Batch Norm?.
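One way to test that belief in isolation, sketched with a standalone nn.BatchNorm1d instead of the MAML model (layer size and random inputs are arbitrary): as far as I understand PyTorch, .eval() only flips the training flag, and the running statistics are buffers that change only when a forward pass runs in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# A forward pass in training mode updates running_mean / running_var in place.
bn.train()
bn(torch.randn(8, 4) + 3.0)
after_train = bn.running_mean.clone()

# Switching to eval mode only flips the flag; it does not reset the buffers.
bn.eval()
after_eval_call = bn.running_mean.clone()

# A forward pass in eval mode reads the buffers but leaves them untouched.
with torch.no_grad():
    bn(torch.randn(8, 4) - 5.0)
after_eval_forward = bn.running_mean.clone()

print(torch.equal(after_train, after_eval_call))     # True
print(torch.equal(after_train, after_eval_forward))  # True
print("running_mean" in bn.state_dict())             # True
```

If this holds for a plain layer, the stats are more likely being lost somewhere else (e.g. in how the checkpoint is written or which model object gets saved).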

I think a deep copy of the model during evaluation should fix things, but that seems like it would exhaust my GPU memory (and perhaps my CPU memory too).
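A cheaper alternative to deep-copying the whole model might be to snapshot only its buffers, which for a typical network are the BatchNorm running_mean, running_var and num_batches_tracked tensors. The helper names snapshot_buffers and restore_buffers below are hypothetical, and the toy model is an assumption for illustration:

```python
import torch
import torch.nn as nn


def snapshot_buffers(model: nn.Module) -> dict:
    """Clone every registered buffer (for BatchNorm: running_mean, running_var, num_batches_tracked)."""
    return {name: buf.detach().clone() for name, buf in model.named_buffers()}


def restore_buffers(model: nn.Module, snapshot: dict) -> None:
    """Copy the saved buffer values back into the model in place."""
    buffers = dict(model.named_buffers())
    with torch.no_grad():
        for name, saved in snapshot.items():
            buffers[name].copy_(saved)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

snap = snapshot_buffers(model)
model.train()
model(torch.randn(16, 4))     # a train-mode forward pass overwrites the running stats
restore_buffers(model, snap)  # put the saved statistics back
```

The buffers are a few vectors per normalization layer, so this copy should cost far less memory than duplicating all parameters.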

Thus, what is the proper way to alternate between evaluation and training in PyTorch so that checkpoints are saved properly (i.e. the running stats from training are NOT deleted)?

This is a challenge because I always run evaluation to check whether the current model does better on validation than the previous one, and only then decide whether to save it. That means evaluation always happens before checkpointing.
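The evaluate-then-maybe-checkpoint pattern can be sketched with a small context manager that switches to eval mode and restores the previous mode afterwards, even if evaluation raises. The name evaluating, the toy model and the in-memory "checkpoint" are assumptions for illustration. The torch.no_grad() is only appropriate for plain supervised evaluation; a MAML inner loop still needs gradients, so it would be dropped there.

```python
import contextlib
import io

import torch
import torch.nn as nn


@contextlib.contextmanager
def evaluating(model: nn.Module):
    """Temporarily put `model` in eval mode, then restore its previous mode."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
model.train()
model(torch.randn(16, 4))  # one "training" step so the running stats are non-trivial

best_val_loss = float("inf")
with evaluating(model), torch.no_grad():
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    val_loss = nn.functional.mse_loss(model(x), y).item()

if val_loss < best_val_loss:
    best_val_loss = val_loss
    buffer = io.BytesIO()  # stands in for a checkpoint file such as ckpt_best_val.pt
    torch.save(model.state_dict(), buffer)
```

state_dict() includes the BatchNorm buffers regardless of the current train/eval flag, so saving inside or after the evaluation block should store the same running statistics.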

Perhaps I could always run evaluation in train mode so that the running averages are saved correctly. Since the exact validation value doesn't matter, it may be fine if statistics from train and val leak into the validation.
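One caveat with always evaluating in train mode: the validation batches are then folded into the running averages that end up in the checkpoint, even inside torch.no_grad(), because the buffer update happens in the forward pass rather than through autograd. A minimal sketch with a standalone layer and synthetic, deliberately shifted "validation" data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
bn.train()
bn(torch.randn(64, 4))  # "training" data centred near 0
stats_before_val = bn.running_mean.clone()

# "Validation" pass done in train mode: no_grad does NOT stop the buffer update.
with torch.no_grad():
    bn(torch.randn(64, 4) + 10.0)  # shifted "validation" data

drift = (bn.running_mean - stats_before_val).abs().max().item()
print(drift > 0.5)  # True
```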

I guess using batch statistics all the time (in both training and evaluation) is another option...
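For the "batch stats all the time" option, PyTorch's BatchNorm layers accept track_running_stats=False: no running buffers are kept, and the layer normalizes with the current batch's statistics in both train and eval mode. A small check with an arbitrary layer size:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4, track_running_stats=False)
print(bn.running_mean is None)  # True: no running statistics are stored

x = torch.randn(8, 4) * 3.0 + 2.0
bn.train()
out_train = bn(x)
bn.eval()
out_eval = bn(x)
print(torch.allclose(out_train, out_eval))  # True: batch statistics used in both modes
print(sorted(bn.state_dict().keys()))       # ['bias', 'weight']
```

The trade-off is that eval-time outputs then depend on which examples share a batch, and there is nothing batch-norm-specific left in the checkpoint to lose.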

Some snippets of code I usually use for MAML (meta-learning, but they should be trivial to adapt to normal supervised learning):

```python
from argparse import Namespace

# uutils, process_meta_batch, gradient_clip, log_train_val_stats,
# log_sim_to_check_presence_of_feature_reuse and save_for_meta_learning
# are defined elsewhere in my code base (not shown).


def meta_train_fixed_iterations_full_epoch_possible(args):
    """
    Train using meta-training (e.g. episodic training) over batches of tasks for a fixed number of iterations,
    assuming the number of tasks is small, i.e. one epoch is doable and not infinite/super exponential
    (e.g. in regression when a task can be considered as a function).

    Note: if the number of tasks is small, then we have two loops: one that runs until all fixed iterations are done
    and another over the dataloader for the tasks.
    """
    # warnings.simplefilter("ignore")
    # uutils.torch_uu.distributed.dist_log('Starting training...')
    print('Starting training!')

    # bar_it = uutils.get_good_progressbar(max_value=progressbar.UnknownLength)
    bar_it = uutils.get_good_progressbar(max_value=args.num_its)
    args.it = 0
    while True:
        for batch_idx, batch in enumerate(args.dataloaders['train']):
            args.batch_idx = batch_idx
            spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

            # - clean gradients, especially before meta-learner is run since it uses gradients
            args.outer_opt.zero_grad()

            # - forward pass A(f)(x)
            train_loss, train_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)

            # - outer_opt step
            gradient_clip(args, args.outer_opt)  # do gradient clipping: * If ‖g‖ ≥ c Then g := c * g/‖g‖
            args.outer_opt.step()

            # - scheduler
            if (args.it % 500 == 0 and args.it != 0 and args.scheduler is not None) or args.debug:  # call scheduler every 500 iterations
                args.scheduler.step()

            # -- log it stats
            log_train_val_stats(args, args.it, train_loss, train_acc, valid=meta_eval, bar=bar_it,
                                log_freq=100, ckpt_freq=500,
                                save_val_ckpt=True, log_to_wandb=args.log_to_wandb)
            log_sim_to_check_presence_of_feature_reuse(args, args.it,
                                                       spt_x, spt_y, qry_x, qry_y,
                                                       log_freq_for_detection_of_feature_reuse=int(args.num_its // 3),
                                                       parallel=False)

            # - break
            halt: bool = args.it >= args.num_its - 1
            if halt:
                return train_loss, train_acc

            args.it += 1

# - evaluation code

def meta_eval(args: Namespace, val_iterations: int = 0, save_val_ckpt: bool = True, split: str = 'val') -> tuple:
    """
    Evaluates the meta-learner on the given meta-set.

    ref for BN/eval:
        - https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du
        - https://github.com/tristandeleu/pytorch-maml/issues/19
    """
    # - need to re-implement if you want to go through the entire data-set to compute an epoch (no more is ever needed)
    assert val_iterations == 0, f'Val iterations has to be zero but got {val_iterations}, if you want more precision increase (meta) batch size.'
    args.meta_learner.eval()
    for batch_idx, batch in enumerate(args.dataloaders[split]):
        spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

        # Forward pass
        eval_loss, eval_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)

        # stop after val_iterations + 1 batches (a single batch when val_iterations == 0)
        if batch_idx >= val_iterations:
            break

    save_val_ckpt = False if split == 'test' else save_val_ckpt  # don't save models based on test set
    if float(eval_loss) < float(args.best_val_loss) and save_val_ckpt:
        args.best_val_loss = float(eval_loss)
        save_for_meta_learning(args, ckpt_filename='ckpt_best_val.pt')
    return eval_loss, eval_acc
```
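To narrow down where the running averages disappear, one option is to load a saved checkpoint and look for BatchNorm buffers that still hold their initial values (running_mean all zeros, running_var all ones). This sketch assumes the checkpoint is a plain state_dict; save_for_meta_learning may nest the weights under another key, and the helper name untouched_bn_buffers is hypothetical.

```python
import io

import torch
import torch.nn as nn


def untouched_bn_buffers(state_dict: dict) -> list:
    """Return names of running_mean/running_var tensors that still hold their initial values."""
    suspicious = []
    for name, tensor in state_dict.items():
        if name.endswith("running_mean") and bool((tensor == 0).all()):
            suspicious.append(name)
        elif name.endswith("running_var") and bool((tensor == 1).all()):
            suspicious.append(name)
    return suspicious


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
fresh = io.BytesIO()
torch.save(model.state_dict(), fresh)  # checkpoint before any training forward pass

model.train()
model(torch.randn(4, 3, 16, 16))
trained = io.BytesIO()
torch.save(model.state_dict(), trained)  # checkpoint after one training forward pass

fresh.seek(0)
trained.seek(0)
fresh_report = untouched_bn_buffers(torch.load(fresh))
trained_report = untouched_bn_buffers(torch.load(trained))
print(fresh_report)    # ['1.running_mean', '1.running_var']
print(trained_report)  # []
```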


Charlie Parker
  • If your model's training is sensitive to whether you delete the running averages in a batch norm or not, then your architecture hasn't learned the data distribution (yet) – Alexey S. Larionov Nov 05 '21 at 16:53
  • My point is, it shouldn't matter *that* much – Alexey S. Larionov Nov 05 '21 at 16:53
  • @AlexeyLarionov perhaps (though I am seeing divergence issues in testing when using batch stats). However, on the overall point, I am trying to figure out why on earth the code would not save the stats. I don't think it's good that values are being deleted without us knowing what is causing it. – Charlie Parker Nov 05 '21 at 17:05
  • So **the main mystery is to figure out how my models were saved and their running averages from training removed** ref: https://discuss.pytorch.org/t/how-does-one-use-the-mean-and-std-from-training-in-batch-norm/136029/5 – Charlie Parker Nov 05 '21 at 19:09
