I usually alternate between training and testing in my code. This results in me calling mdl.train() when coming back to the training loop and mdl.eval() when coming to the testing loop. However, I've noticed that this potentially creates issues with how parameters are saved (e.g. batch norm, perhaps others) that depend on the train/eval state. This raises the question of how these flags should be called, if at all. If I call .eval() before checkpointing, I believe this will clobber my running averages, as mentioned here: How does one use the mean and std from training in Batch Norm?.
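To check whether .eval() actually erases anything, here is a minimal sanity-check sketch (a toy supervised model, not my MAML code), based on my understanding of how BN buffers work:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model.train()
model(torch.randn(8, 4))  # train mode: the forward pass updates running_mean/running_var
before = model[1].running_mean.clone()

model.eval()
model(torch.randn(8, 4))  # eval mode: the running stats are used, not updated
after = model[1].running_mean

print(torch.equal(before, after))  # expect True: .eval() itself shouldn't erase the buffers
print(any('running_mean' in k for k in model.state_dict()))  # expect True: buffers live in state_dict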
I think a deep copy of the model during evaluation would fix things, but that seems like it would blow up my GPU memory (and perhaps my normal RAM too).
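Something like this is what I have in mind (hypothetical sketch; model stands for whatever the meta-learner wraps):

import copy

# evaluate a throwaway copy so the original's BN state is untouched;
# moving the copy to CPU would spare GPU memory at the cost of slow forward passes
eval_model = copy.deepcopy(model).cpu().eval()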
Thus, what is the proper way to alternate between evaluation and training in PyTorch so that checkpoints are saved properly (i.e. so the running stats from training are NOT deleted)?
This is a challenge because I always run evaluation to see if the current model is better in validation than the previous one, and based on that decide whether to save it or not, which means I always have to run an evaluation before checkpointing.
Perhaps I could just leave the model in train mode all the time, so that the running averages are updated and saved correctly; since the exact validation value doesn't matter much, it's fine if the batch statistics from train and val leak into each other.
I guess running everything with batch stats all the time is another option...
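For reference, the vanilla supervised pattern I'm trying to map this onto is roughly the following (a sketch with hypothetical helpers evaluate and save_ckpt, not my actual code):

for epoch in range(num_epochs):
    model.train()  # back to train mode: BN updates its running stats from batch stats
    for x, y in train_loader:
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()

    model.eval()  # eval mode: BN uses (and does not update) the running stats
    with torch.no_grad():
        val_loss = evaluate(model, val_loader)  # hypothetical helper

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_ckpt(model)  # hypothetical helper; model.state_dict() includes the BN buffers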
Some snippets of the code I usually use for MAML (meta-learning, but it should be trivial to adapt to normal supervised learning):
from argparse import Namespace  # (other helpers -- uutils, process_meta_batch, etc. -- come from my own library)

def meta_train_fixed_iterations_full_epoch_possible(args):
    """
    Train using meta-training (i.e. episodic training) over batches of tasks for a fixed number of iterations,
    assuming the number of tasks is small, i.e. one epoch is doable and not infinite/super exponential
    (e.g. in regression, where a task can be considered a function).

    Note: if the number of tasks is small then we have two loops: one while we have not finished all the fixed
    iterations and another over the dataloader for the tasks.
    """
    print('Starting training!')
    bar_it = uutils.get_good_progressbar(max_value=args.num_its)
    args.it = 0
    while True:
        for batch_idx, batch in enumerate(args.dataloaders['train']):
            args.batch_idx = batch_idx
            spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

            # - clean gradients, especially before the meta-learner is run, since it uses gradients
            args.outer_opt.zero_grad()

            # - forward pass A(f)(x)
            train_loss, train_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)

            # - outer_opt step (with gradient clipping: if ‖g‖ ≥ c then g := c * g/‖g‖)
            gradient_clip(args, args.outer_opt)
            args.outer_opt.step()

            # - scheduler step every 500 iterations
            if (args.it % 500 == 0 and args.it != 0 and args.scheduler is not None) or args.debug:
                args.scheduler.step()

            # -- log iteration stats; note `valid=meta_eval` means evaluation (and thus .eval())
            # runs right before each checkpoint is saved
            log_train_val_stats(args, args.it, train_loss, train_acc, valid=meta_eval, bar=bar_it,
                                log_freq=100, ckpt_freq=500,
                                save_val_ckpt=True, log_to_wandb=args.log_to_wandb)
            log_sim_to_check_presence_of_feature_reuse(args, args.it,
                                                       spt_x, spt_y, qry_x, qry_y,
                                                       log_freq_for_detection_of_feature_reuse=int(args.num_its // 3),
                                                       parallel=False)

            # - halt once the fixed number of iterations has been reached
            halt: bool = args.it >= args.num_its - 1
            if halt:
                return train_loss, train_acc
            args.it += 1
# - evaluation code
def meta_eval(args: Namespace, val_iterations: int = 0, save_val_ckpt: bool = True, split: str = 'val') -> tuple:
    """
    Evaluates the meta-learner on the given meta-set.

    Refs for BN/eval:
    - https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du
    - https://github.com/tristandeleu/pytorch-maml/issues/19
    """
    # - needs re-implementing if you want to go through the entire data-set to compute a full epoch (no more is ever needed)
    assert val_iterations == 0, f'Val iterations has to be zero but got {val_iterations}, if you want more precision increase the (meta) batch size.'
    args.meta_learner.eval()  # <- the call I worry about w.r.t. the checkpointed BN state
    for batch_idx, batch in enumerate(args.dataloaders[split]):
        spt_x, spt_y, qry_x, qry_y = process_meta_batch(args, batch)

        # - forward pass
        eval_loss, eval_acc = args.meta_learner(spt_x, spt_y, qry_x, qry_y)

        # - store eval info (with val_iterations == 0, only the first batch is used)
        if batch_idx >= val_iterations:
            break

    save_val_ckpt = False if split == 'test' else save_val_ckpt  # never save models based on the test set
    if float(eval_loss) < float(args.best_val_loss) and save_val_ckpt:
        args.best_val_loss = float(eval_loss)
        save_for_meta_learning(args, ckpt_filename='ckpt_best_val.pt')
    return eval_loss, eval_acc
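In case it matters, save_for_meta_learning ultimately boils down to something like this (a simplified, hypothetical version of my helper; args.log_root is an assumed attribute):

def save_for_meta_learning(args, ckpt_filename: str = 'ckpt.pt'):
    # torch.save on a state_dict snapshot; note that state_dict() includes the BN
    # running_mean/running_var buffers regardless of train/eval mode
    torch.save({
        'it': args.it,
        'best_val_loss': args.best_val_loss,
        'meta_learner_state_dict': args.meta_learner.state_dict(),
        'outer_opt_state_dict': args.outer_opt.state_dict(),
    }, args.log_root / ckpt_filename)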
Related:
- When should one call .eval() and .train() when doing MAML with the PyTorch higher library?
- https://discuss.pytorch.org/t/how-does-one-use-the-mean-and-std-from-training-in-batch-norm/136029
- https://discuss.pytorch.org/t/proper-way-to-call-eval-and-train-so-that-checkpointing-saves-all-parameters/136100