
Does model.train() call forward() in nn.Module? I thought that when we call the model, the forward method is used. Why do we need to specify train()?

aerin
  • These days there is documentation for this in PyTorch (https://pytorch.org/docs/stable/generated/torch.nn.Module.html), and I think it describes this pretty clearly. Other libraries/frameworks can lack documentation, but PyTorch's official documentation is pretty good. – Konstantin Burlachenko Oct 19 '21 at 12:34
  • Perhaps "configure_training" or "set_training_mode" would have been better names for this function. – Rexcirus Dec 01 '21 at 13:42
  • It simply sets `self.training = mode` recursively for all submodules; `eval()` just calls `self.train(False)`. In fact, that is all `self.train` does: it changes the flag (to `True` by default) recursively for all modules. See the code: https://github.com/pytorch/pytorch/blob/6e1a5b1196aa0277a2113a4bca75b6e0f2b4c0c8/torch/nn/modules/module.py#L1432 – Charlie Parker Dec 19 '21 at 19:07

6 Answers


model.train() tells your model that you are training it. This informs layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode BatchNorm normalizes with the statistics of the current batch and updates its moving averages on each new batch, whereas in evaluation mode these updates are frozen and the stored moving averages are used instead.
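To make the BatchNorm point concrete, here is a minimal sketch (assuming a recent PyTorch; the batch size, feature count, and +5 offset are arbitrary choices for illustration) that checks `running_mean` after a forward pass in each mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=3)   # running_mean starts at zeros
x = torch.randn(8, 3) + 5.0           # a batch whose mean is far from zero

# Training mode: the forward pass updates the running statistics.
bn.train()
bn(x)
mean_after_train = bn.running_mean.clone()

# Evaluation mode: the running statistics are used, not updated.
bn.eval()
bn(x)
mean_after_eval = bn.running_mean.clone()

print(mean_after_train)                                 # moved away from zero
print(torch.equal(mean_after_train, mean_after_eval))   # True: frozen in eval mode
```

With the default `momentum=0.1`, one training step moves each running mean about a tenth of the way toward the batch mean.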

More details: model.train() sets the mode to train (see the source code). You can call either model.eval() or model.train(mode=False) to tell the model that you are testing. It is somewhat intuitive to expect the train function to train the model, but it does not do that. It just sets the mode.
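A small sketch of the "it just sets the mode" point. `CountingNet` and its `forward_calls` counter are hypothetical, added only to observe when `forward()` actually runs:

```python
import torch
import torch.nn as nn

class CountingNet(nn.Module):
    """Hypothetical toy module that counts how often forward() runs."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return self.linear(x)

model = CountingNet()
returned = model.train()          # only flips the mode flag
print(model.forward_calls)        # 0: train() did not run forward()
print(returned is model)          # True: train() returns the module itself

model(torch.randn(1, 4))          # calling the model is what runs forward()
print(model.forward_calls)        # 1

model.train(False)                # same effect as model.eval()
print(model.training)             # False
```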

Mateen Ulhaq
Umang Gupta
  • Is there a flag to detect if the model is in eval mode? e.g. `mdl.is_eval()`? – Charlie Parker May 12 '21 at 17:43
  • Use the `model.training` flag. It is `False` when in `eval` mode. – Umang Gupta May 12 '21 at 18:16
  • In the current documentation, I find this "model.train()" is no longer being used: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html I did a small test with a small 3 layer neural network model with batch norm and dropout and trained it on tabular dataset. I found adding model.train() actually prevented my model accuracy going above 70%. When I removed the line, the accuracy was 87%! – Indrajit Aug 07 '21 at 12:43
  • @Indrajit Did you check that was not in train model, i.e., `model.training` is `False`? I think by default it is true and that is why they omit `model.train()` call. As for your result, I cannot say much without knowing what the data was and if you measure test or train accuracy etc. – Umang Gupta Aug 16 '21 at 21:47
  • @UmangGupta - It is true that by default model.training is True, but if you look at the link, their training loop, after the train step they have an eval step - where they call model.eval(). This will make model.training as False which they don't reset. I understand this is pretty counter intuitive - and I am puzzled too. Still trying to understand why this is happening. – Indrajit Aug 19 '21 at 11:13
  • @UmangGupta - actually I just figured out what was happening. My model.train() was actually affecting the batchnorm and dropout layers, which in turn was affecting the model performance. – Indrajit Aug 19 '21 at 11:49
  • I wonder how `model.eval()` affects backward pass? – mrgloom Dec 06 '21 at 18:32
  • It won't affect the backward pass directly. It will just use batch norm and dropout in "test" mode which can affect the backward pass. For example, Dropout would become an identity operation and thus have no effect. – Umang Gupta Dec 06 '21 at 21:01
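On the backward-pass question raised in these comments, a minimal sketch (the toy model is chosen only for illustration): eval mode changes how Dropout behaves, but it does not turn off gradient tracking; that is the job of `torch.no_grad()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
x = torch.randn(2, 4)

model.eval()                      # Dropout now acts as the identity
loss = model(x).sum()
loss.backward()                   # gradients are still computed in eval mode
print(model[0].weight.grad is not None)   # True

# Disabling gradient tracking is a separate switch: torch.no_grad().
with torch.no_grad():
    out = model(x)
print(out.requires_grad)          # False
```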

Here is the code for nn.Module.train():

def train(self, mode=True):
    r"""Sets the module in training mode."""
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

Here is the code for nn.Module.eval():

def eval(self):
    r"""Sets the module in evaluation mode."""
    return self.train(False)

By default, the self.training flag is set to True, i.e., modules are in train mode by default. When self.training is False, the module is in the opposite state, eval mode.
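Because `train()` recurses through `children()`, a single call reaches nested submodules too. A quick check on an arbitrary nested `nn.Sequential` (the layer sizes are illustrative only):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(8, 8),
    nn.Sequential(nn.BatchNorm1d(8), nn.Dropout(0.2)),  # nested container
)

# Modules start in training mode.
print(all(m.training for m in model.modules()))   # True

model.eval()  # calls train(False), which recurses through children()
print(any(m.training for m in model.modules()))   # False

model.train()
print(all(m.training for m in model.modules()))   # True
```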

Of the most commonly used layers, only Dropout and BatchNorm care about that flag.

Mateen Ulhaq
prosti
  • Are there any other layers that support `self.training` flag now ? – Melike Mar 05 '21 at 18:13
  • I wonder how `model.eval()` affects backward pass? – mrgloom Dec 06 '21 at 18:32
  • `model.eval()` is just a switch that changes how dropout and batch norm layers behave. I have a nice [intro to PyTorch training](https://programming-review.com/pytorch/train) where you can check the forward and backward pass, and [deep intro to PyTorch AD](https://programming-review.com/pytorch/ad) where you can confidently understand the details of PyTorch AD. – prosti Dec 06 '21 at 18:45
  • @Melike https://stackoverflow.com/questions/66534762/which-pytorch-modules-are-affected-by-model-eval-and-model-train – iacob Jul 29 '22 at 11:24
model.train():
Sets the model in training mode, i.e.
  • BatchNorm layers use per-batch statistics
  • Dropout layers are activated, etc.

model.eval():
Sets the model in evaluation (inference) mode, i.e.
  • BatchNorm layers use running statistics
  • Dropout layers are deactivated, etc.
Equivalent to model.train(False).

Note: neither of these function calls runs a forward or backward pass. They tell the model how to act when it is run.

This is important because some modules (layers), e.g. Dropout and BatchNorm, are designed to behave differently during training and inference, so the model will produce unexpected results if it is run in the wrong mode.
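As a small illustration of the Dropout behavior, assuming the default `nn.Dropout` behavior of scaling surviving elements by 1 / (1 - p) during training (the input of ones is just an easy-to-read choice):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y = drop(x)
# Zeroed elements become 0.0; survivors are scaled by 1 / (1 - 0.5) = 2.0.
print(sorted(set(y.tolist())))   # [0.0, 2.0]

drop.eval()
print(torch.equal(drop(x), x))   # True: Dropout is the identity in eval mode
```

Running a trained model in training mode by mistake therefore gives randomly masked, rescaled activations on every call.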

iacob

There are two ways of letting the model know your intention, i.e. whether you want to train the model or use it for evaluation. model.train() tells the model that it is being trained, and model.eval() tells it that nothing new is to be learnt and it is being used for testing; neither call updates any weights by itself. model.eval() is also necessary with BatchNorm when testing on a single sample: in training mode BatchNorm needs more than one value per channel to compute batch statistics, so passing, for example, a single feature vector through BatchNorm1d raises an error unless model.eval() is called first.
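A minimal sketch of the single-sample case. Note the assumption behind it: the error is raised when each channel has only one value, as with a `(1, C)` input to `BatchNorm1d`; a single image fed to `BatchNorm2d` with more than one spatial position per channel does not raise.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=4)
single_sample = torch.randn(1, 4)   # batch of one: one value per channel

bn.train()
raised_in_train = False
try:
    bn(single_sample)
except ValueError as err:
    # Batch statistics cannot be estimated from a single value per channel.
    raised_in_train = True
    print("train mode:", err)

bn.eval()
out = bn(single_sample)             # uses the running statistics instead
print("eval mode output shape:", tuple(out.shape))   # (1, 4)
```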

Chris Tang
kelam gautam

The current official documentation states the following:

This has any [sic] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
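For modules that do not read the flag, the mode makes no difference. A quick check with an arbitrary Linear/ReLU stack (sizes are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 2))
x = torch.randn(4, 3)

model.train()
out_train = model(x)
model.eval()
out_eval = model(x)

# No layer here reads self.training, so switching modes changes nothing.
print(torch.equal(out_train, out_eval))   # True
```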

iacob
Konstantin Burlachenko

Consider the following model

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

Here, dropout behaves differently in the different modes of operation. As you can see, it is applied only when self.training == True. So after model.train(), the model's forward function performs dropout; otherwise (e.g. after model.eval() or model.train(mode=False)) it does not.
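To try the same pattern without `torch_geometric`, here is a hedged sketch in plain PyTorch. `TinyNet` and its `pass_training_flag` switch are hypothetical, and `nn.Linear` stands in for `GCNConv`. It also shows why `training=self.training` matters: `F.dropout` defaults to `training=True`, so omitting the argument keeps dropout active even after `model.eval()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    """Hypothetical stand-in for GraphNet, using nn.Linear instead of GCNConv."""

    def __init__(self, pass_training_flag=True):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.pass_training_flag = pass_training_flag

    def forward(self, x):
        x = self.fc(x)
        if self.pass_training_flag:
            # Dropout follows model.train() / model.eval().
            return F.dropout(x, p=0.5, training=self.training)
        # Deliberate bug: F.dropout defaults to training=True,
        # so dropout stays active even after model.eval().
        return F.dropout(x, p=0.5)

def is_deterministic(model, x):
    """True if two forward passes on the same input give identical outputs."""
    return torch.equal(model(x), model(x))

torch.manual_seed(0)
x = torch.randn(64, 8)

good = TinyNet(pass_training_flag=True).eval()
bad = TinyNet(pass_training_flag=False).eval()
print(is_deterministic(good, x))   # True
print(is_deterministic(bad, x))    # False: dropout is still active
```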

Lawhatre