Does it call `forward()` in `nn.Module`? I thought that when we call the model, the `forward` method is what gets invoked. Why do we need to specify `train()`?

- These days there is documentation inside PyTorch: https://pytorch.org/docs/stable/generated/torch.nn.Module.html — it describes this pretty clearly, I think. Other libraries/frameworks can lack documentation, but PyTorch's official documentation is quite good. – Konstantin Burlachenko Oct 19 '21 at 12:34
- Perhaps "configure_training" or "set_training_mode" would have been better names for this function. – Rexcirus Dec 01 '21 at 13:42
- It simply changes `self.training` via `self.training = mode`, recursively for all modules; `eval()` does this by calling `self.train(False)`. In fact, that is all `self.train` does: it changes the flag recursively for all modules. See the code: https://github.com/pytorch/pytorch/blob/6e1a5b1196aa0277a2113a4bca75b6e0f2b4c0c8/torch/nn/modules/module.py#L1432 – Charlie Parker Dec 19 '21 at 19:07
6 Answers
`model.train()` tells your model that you are training it. This helps inform layers such as `Dropout` and `BatchNorm`, which are designed to behave differently during training and evaluation. For instance, in training mode, `BatchNorm` updates a moving average on each new batch, whereas in evaluation mode these updates are frozen.

More details: `model.train()` sets the mode to train (see the source code). You can call either `model.eval()` or `model.train(mode=False)` to signal that you are testing. It is somewhat intuitive to expect the `train` function to train the model, but it does not do that: it just sets the mode.
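
As a rough sketch of where these calls typically sit in a training loop (`train_loader` and `val_loader` are assumed to be existing `DataLoader`s; the model here is just a placeholder):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(),
                          nn.Dropout(0.5), nn.Linear(32, 2))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(10):
        model.train()              # switch to train mode: Dropout is active
        for inputs, targets in train_loader:    # assumed DataLoader
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        model.eval()               # switch to eval mode: Dropout is a no-op
        with torch.no_grad():      # separately, disable gradient tracking
            for inputs, targets in val_loader:  # assumed DataLoader
                outputs = model(inputs)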

- Is there a flag to detect if the model is in eval mode? E.g. `mdl.is_eval()`? – Charlie Parker May 12 '21 at 17:43
- In the current documentation, I find that `model.train()` is no longer being used: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html I did a small test with a small 3-layer neural network with batch norm and dropout, trained on a tabular dataset. I found that adding `model.train()` actually prevented my model accuracy from going above 70%. When I removed the line, the accuracy was 87%! – Indrajit Aug 07 '21 at 12:43
- @Indrajit Did you check that it was not already in train mode, i.e., whether `model.training` was `False`? I think by default it is `True`, and that is why they omit the `model.train()` call. As for your result, I cannot say much without knowing what the data was and whether you measured test or train accuracy, etc. – Umang Gupta Aug 16 '21 at 21:47
- @UmangGupta It is true that by default `model.training` is `True`, but if you look at the link, their training loop has an eval step after the train step, where they call `model.eval()`. This sets `model.training` to `False`, and they don't reset it. I understand this is pretty counter-intuitive, and I am puzzled too. Still trying to understand why this is happening. – Indrajit Aug 19 '21 at 11:13
- @UmangGupta Actually, I figured out just now what was happening. My `model.train()` was actually impacting the batch norm and dropout layers, which in turn were impacting the model performance. – Indrajit Aug 19 '21 at 11:49
- It won't affect the backward pass directly. It will just use batch norm and dropout in "test" mode, which can affect the backward pass. For example, `Dropout` becomes an identity operation and thus has no effect. – Umang Gupta Dec 06 '21 at 21:01
Here is the code for `nn.Module.train()`:
    def train(self, mode=True):
        r"""Sets the module in training mode."""
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self
Here is the code for `nn.Module.eval()`:
    def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)
By default, the `self.training` flag is set to `True`, i.e., modules are in train mode by default. When `self.training` is `False`, the module is in the opposite state, eval mode. Of the most commonly used layers, only `Dropout` and `BatchNorm` care about that flag.
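
Since it is just a flag, you can inspect `model.training` directly to check which mode a module is in (there is no dedicated `is_eval()` method). A quick sketch:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

    print(model.training)     # True: modules start out in train mode
    model.eval()
    print(model.training)     # False
    print(model[1].training)  # False: eval() recursed into the Dropout child
    model.train()
    print(model.training)     # True again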

- `model.eval()` is just a switch that turns off dropout and makes batch norm use its stored statistics. I have a nice [intro to PyTorch training](https://programming-review.com/pytorch/train) where you can check the forward and backward pass, and a [deep intro to PyTorch AD](https://programming-review.com/pytorch/ad) where you can confidently understand the details of PyTorch AD. – prosti Dec 06 '21 at 18:45
- @Melike https://stackoverflow.com/questions/66534762/which-pytorch-modules-are-affected-by-model-eval-and-model-train – iacob Jul 29 '22 at 11:24
| `model.train()` | `model.eval()` |
|---|---|
| Sets model in training mode, i.e. • `BatchNorm` layers use per-batch statistics • `Dropout` layers activated, etc. | Sets model in evaluation (inference) mode, i.e. • `BatchNorm` layers use running statistics • `Dropout` layers de-activated, etc. Equivalent to `model.train(False)`. |
Note: neither of these function calls runs forward/backward passes. They tell the model how to act when run.

This is important, as some modules (layers) (e.g. `Dropout`, `BatchNorm`) are designed to behave differently during training vs. inference, and hence the model will produce unexpected results if run in the wrong mode.
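
As a small sketch of the `BatchNorm` row above (the variable names are just illustrative): in train mode a forward pass updates the running statistics, while in eval mode they stay frozen.

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    x = torch.randn(8, 3) + 5.0   # batch whose mean is far from 0

    before = bn.running_mean.clone()
    bn.train()
    bn(x)                         # normalises with per-batch stats, updates running stats
    print(torch.allclose(before, bn.running_mean))  # False: running mean moved

    frozen = bn.running_mean.clone()
    bn.eval()
    bn(x)                         # normalises with running stats, no update
    print(torch.equal(frozen, bn.running_mean))     # True: unchanged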

There are two ways of letting the model know your intention, i.e., whether you want to train the model or use it for evaluation. With `model.train()` the model knows it has to learn, and with `model.eval()` you indicate that nothing new is to be learnt and the model is being used for testing.

`model.eval()` is also necessary because, in PyTorch, if we are using batch norm and want to pass just a single image at test time, PyTorch throws an error if `model.eval()` is not specified.
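
A minimal sketch of that failure mode (the exact error text may vary across PyTorch versions):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))
    single = torch.randn(1, 8)   # a "batch" containing a single sample

    try:
        net(single)              # train mode: batch statistics are undefined for n=1
    except ValueError as err:
        print(err)               # e.g. "Expected more than 1 value per channel when training ..."

    net.eval()
    out = net(single)            # eval mode: running statistics are used instead, so this works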

The current official documentation states the following:
This has any [sic] effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
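
To illustrate the quoted note with a small sketch: a mode-insensitive layer such as `nn.Linear` produces identical outputs in both modes, while `nn.Dropout` does not.

    import torch
    import torch.nn as nn

    x = torch.randn(4, 4)
    linear = nn.Linear(4, 4)
    drop = nn.Dropout(0.5)

    linear.train()
    drop.train()
    lin_train, drop_train = linear(x), drop(x)

    linear.eval()
    drop.eval()
    lin_eval, drop_eval = linear(x), drop(x)

    print(torch.equal(lin_train, lin_eval))    # True: Linear ignores the mode
    print(torch.equal(drop_train, drop_eval))  # (almost surely) False: Dropout is the
                                               # identity in eval, random masking in train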

Consider the following model:

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    class GraphNet(torch.nn.Module):
        def __init__(self, num_node_features, num_classes):
            super(GraphNet, self).__init__()
            self.conv1 = GCNConv(num_node_features, 16)
            self.conv2 = GCNConv(16, num_classes)

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            x = self.conv1(x, edge_index)
            x = F.dropout(x, training=self.training)  # Look here
            x = self.conv2(x, edge_index)
            return F.log_softmax(x, dim=1)
Here, the functioning of `dropout` differs between the modes of operation: as you can see, it is applied only when `self.training == True`. So when you call `model.train()`, the model's forward function will perform dropout; otherwise it will not (e.g. after `model.eval()` or `model.train(mode=False)`).
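
As a quick illustration (assuming `torch_geometric` is installed; the toy graph below is made up purely for demonstration, continuing from the model defined above), repeated forward passes differ in train mode because fresh dropout masks are drawn, but become deterministic after `model.eval()`:

    from torch_geometric.data import Data

    # Hypothetical toy graph: 4 nodes with 3 features each, 3 directed edges.
    data = Data(x=torch.randn(4, 3),
                edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))

    model = GraphNet(num_node_features=3, num_classes=2)

    model.train()
    a, b = model(data), model(data)
    print(torch.equal(a, b))   # almost surely False: a new dropout mask per call

    model.eval()
    a, b = model(data), model(data)
    print(torch.equal(a, b))   # True: dropout disabled, forward is deterministic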
