Q: Is 'model.train()' really necessary?
A: See the question "What does model.train() do in PyTorch?"
And the reference code: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train
But is it necessary?
Compulsory, no. Necessary, most probably. Layers like Dropout and BatchNorm behave differently in train and eval mode, so if the model is run without specifying the mode, it might behave unexpectedly.
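For instance, here is a minimal sketch with a toy model (not the code in question) showing how Dropout reacts to the mode:

```python
import torch

# Toy model: Dropout is active in train mode and disabled in eval mode.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Dropout(p=0.5),
)
x = torch.randn(1, 4)

model.train()                # dropout is active, so outputs are stochastic
print(model(x))
print(model(x))

model.eval()                 # dropout is disabled, so outputs are deterministic
with torch.no_grad():
    print(model(x))
    print(model(x))
```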
Q: Why does it use 'optimizer.zero_grad()' twice?
A: It depends on what is inside the loss_fn operation and the model.forward() function.
Q: What does optimizer.zero_grad() do?
A: See the question "Why do we need to call zero_grad() in PyTorch?"
Normally, you have the code flow as such:
1. pred = model(...)
   Model forward propagation.
2. loss = loss_fn(pred, truth)
   Loss computation; the loss tensor carries the computation graph needed to back-propagate gradients to each node.
3. loss.backward()
   Back-propagate the gradients.
4. optimizer.step()
   Update the model parameters based on the back-propagated gradients.
5. optimizer.zero_grad()
   Reset the gradients before the next computation.
The normal 1,2,3,4,5 flow is common and then it repeats: 1,2,3,4,5, 1,2,3,4,5, ...
But sometimes you can also see 5,1,2,3,4, 5,1,2,3,4, ... where the gradients are cleared at the start of each iteration instead of at the end.
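Put together, a minimal runnable version of the 1,2,3,4,5 loop could look like this (the toy model, loss and data below are placeholders, not the example code in question):

```python
import torch

# Toy stand-ins for the real model, loss_fn, optimizer and dataloader.
model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]

model.train()
for x, truth in dataloader:
    pred = model(x)              # 1. forward propagation
    loss = loss_fn(pred, truth)  # 2. loss computation
    loss.backward()              # 3. back-propagate the gradients
    optimizer.step()             # 4. update the model parameters
    optimizer.zero_grad()        # 5. reset the gradients for the next iteration
```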
Q: Then why did the example code have a 1,2,5,3,5,4 flow?
A: It's kind of strange and most probably not the normal deep learning flow.
The first 5 (before 3) merely clears any stale gradients, but the second 5 (between 3 and 4) wipes the freshly computed gradients, so optimizer.step() has nothing to apply. I.e. a 5 right before 4 is kind of weird.
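You can check this with a toy model (again, not the original code): with the second 5 in place, the optimizer step changes nothing.

```python
import torch

# Toy demonstration of the 3,5,4 ordering: zeroing the gradients between
# backward() and step() leaves the optimizer with nothing to apply.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

before = model.weight.detach().clone()

loss = torch.nn.functional.mse_loss(model(x), y)   # 2. loss computation
loss.backward()                                    # 3. gradients are computed...
optimizer.zero_grad()                              # 5. ...and immediately wiped
optimizer.step()                                   # 4. steps on empty gradients

print(torch.equal(before, model.weight.detach()))  # True: nothing was updated
```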
But there could be other nuances that are unclear, e.g. does model.forward() have some mini back-propagation inside? Normally it doesn't, but if it does, the author might want to get rid of those gradients before the loss computation. Still, a 5 before 3 and 4 is NOT common; maybe it's a neural net trained without backpropagation, something like the Forward-Forward algorithm: https://arxiv.org/abs/2212.13345
Also, there might be other reasons. For example, maybe the original code accumulated gradients across multiple batches before the optimizer step, with a zero_grad() at the end of the accumulation loop (see the sketch after the gradient accumulation question below).
If that code was copied and pasted, and the gradient-accumulation if-else was removed but some of its artefacts were left behind, the extra zero_grad() call would look like it appears out of nowhere.
Q: What is gradient accumulation?
A: See https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation
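For a rough idea of the pattern (the linked guide uses Hugging Face Accelerate; below is a plain-PyTorch sketch with made-up names like accumulation_steps):

```python
import torch

# Plain-PyTorch sketch of gradient accumulation: gradients from several
# batches are summed in .grad before a single optimizer step, and
# zero_grad() sits at the end of each accumulation cycle.
accumulation_steps = 4
model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(20)]

model.train()
for i, (x, truth) in enumerate(dataloader):
    pred = model(x)
    loss = loss_fn(pred, truth) / accumulation_steps  # scale so the sum averages out
    loss.backward()                                   # gradients accumulate across batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per accumulation_steps batches
        optimizer.zero_grad()  # reset only after the accumulated step
```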