
Hi, I'm going through the PyTorch tutorial on transfer learning (https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html).

What is model.training for?

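import matplotlib.pyplot as plt
import torch

# dataloaders, device, class_names and imshow are defined earlier in the tutorial
def visualize_model(model, num_images=6):
    # Remember which mode the model was in before switching it
    was_training = model.training
    model.eval()  # evaluation mode: Dropout off, BatchNorm uses running stats
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)  # index of the highest score per image

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    # Restore the original mode before returning early
                    model.train(mode=was_training)
                    return
        # Restore the original mode if the loader ran out first
        model.train(mode=was_training)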

I can't understand model.train(mode=was_training). Any help? Thank you so much!

sky park
  • Does this answer your question? [What does model.train() do in PyTorch?](https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch) – Ivan Sep 29 '21 at 06:54
  • Oh, thank you! But now I wonder why they use model.train in the test code. Why do they put that code inside with torch.no_grad()? Isn't it obvious that was_training=False? – sky park Sep 29 '21 at 07:03

2 Answers


I think this will help (link)

All nn.Modules have an internal training attribute, which is changed by calling model.train() and model.eval() to switch the behavior of mode-dependent layers such as Dropout and BatchNorm.
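To make that concrete, here is a minimal sketch, assuming a hypothetical toy model made of a Linear layer followed by Dropout. It shows that every submodule carries its own flag, that model.eval() flips all of them, and that Dropout only drops values while training is True.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only for illustration
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
drop = model[1]  # the Dropout submodule

# Newly built modules start in training mode; each submodule has its own flag
print(model.training, drop.training)  # True True

x = torch.ones(1000)
# Training mode: Dropout zeroes each element with probability p
# and scales the survivors by 1 / (1 - p), so every value is 0.0 or 2.0
train_out = drop(x)
print(set(train_out.tolist()) <= {0.0, 2.0})  # True

model.eval()  # sets training=False on the model AND on every submodule
print(model.training, drop.training)  # False False

# Eval mode: Dropout passes its input through unchanged
eval_out = drop(x)
print(torch.equal(eval_out, x))  # True
```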

The was_training variable stores the model's current training state; the function then calls model.eval() and restores the original state at the end using model.train(mode=was_training).
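Here is a sketch of the same save-and-restore pattern as a small helper. The name predict_preserving_mode is hypothetical (it is not part of PyTorch or the tutorial). It also addresses the follow-up comment on the question: was_training is read before model.eval() is called, so it can be True, for example when the function is called in the middle of training. Unlike the tutorial, which calls model.train(mode=was_training) before each exit, this version uses try/finally so the mode is restored on every exit path, including exceptions.

```python
import torch
import torch.nn as nn

def predict_preserving_mode(model, inputs):
    """Hypothetical helper: run inference in eval mode, then restore the previous mode."""
    was_training = model.training  # read BEFORE eval(), so it may be True
    model.eval()
    try:
        with torch.no_grad():
            return model(inputs)
    finally:
        # train(mode=False) is the same as eval(), so this restores either state
        model.train(mode=was_training)

model = nn.Linear(3, 2)  # toy model, for illustration only
x = torch.randn(5, 3)

model.train()                      # e.g. called in the middle of a training loop
predict_preserving_mode(model, x)
print(model.training)              # True: back in training mode

model.eval()                       # e.g. called after training has finished
predict_preserving_mode(model, x)
print(model.training)              # False: stays in eval mode
```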

You can find great answers on the PyTorch discussion forum ;)

Skarl001

I wonder why they use model.train in the test code. Why do they put that code inside with torch.no_grad()? Isn't it obvious that was_training=False?

It is a somewhat misleading use of train, because train can also put the model into inference (evaluation) mode:

>>> model.train(mode=True)
>>> model.training 
True   # <- train mode

>>> model.train(mode=False)
>>> model.training
False  # <- eval mode

I agree it is not ideal; a more appropriate formulation would simply have been:

>>> model.eval()
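A quick check of that equivalence, assuming a hypothetical nested toy model: in PyTorch, nn.Module.eval() just calls self.train(False), so both calls set training=False on the module and on all of its submodules, and both return the module itself.

```python
import torch.nn as nn

# Hypothetical toy model with a nested submodule, for illustration only
model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Dropout(0.1)))

model.eval()
flags_after_eval = [m.training for m in model.modules()]  # model + all submodules

model.train()             # back to training mode first
model.train(mode=False)
flags_after_train_false = [m.training for m in model.modules()]

print(flags_after_eval == flags_after_train_false)  # True
print(any(flags_after_eval))                        # False: every module is in eval mode

# Both methods return the module itself, so calls can be chained
print(model.eval() is model)  # True
```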
Shai
Ivan