Intro
I am building a classifier that recognizes the presence of defects in pictures. While working on improving my models, I tried Batch Normalization, mainly to exploit its ability to speed up convergence.
While it delivers the expected speed-up, I also observed some strange symptoms:
- validation metrics are far from good. Of course, this smells of overfitting.
- predictions computed at any point during training are completely wrong, particularly when images are taken from the training dataset; the corresponding metrics match (val_loss, val_acc) rather than the (loss, acc) printed during training.
This failure to predict is the symptom that worries me the most: a model that does not predict the same way it behaves during training is useless!
Searches
Googling around, I found some posts that seem related, particularly this one (Keras BN layer is broken), which also mentions a patch and a pull request that sadly "was rejected".
This is quite convincing, because it describes a failure mechanism that matches my observations. As far as I understand, BN computes and keeps moving statistics (exponential moving averages of the mean and variance of its inputs) to do its job. Since these statistics need many iterations to stabilize and become meaningful, BN will naturally behave badly when it has to make a prediction from scratch, while those statistics are not yet mature enough (if I have misunderstood this concept, please tell me).
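The mechanism described above can be sketched in NumPy. This is a simplified stand-in, not Keras' actual implementation: the class name `BatchNormSketch` is hypothetical, normalization is over the batch axis only, `gamma`/`beta` are kept fixed, and the defaults (momentum 0.99, epsilon 1e-3, moving mean starting at 0 and moving variance at 1) are assumed to mirror Keras' `BatchNormalization` defaults.

```python
import numpy as np


class BatchNormSketch:
    """Simplified batch normalization over axis 0 (the batch axis).

    Defaults are assumed to mirror Keras' BatchNormalization: momentum=0.99,
    epsilon=1e-3, moving mean starting at 0 and moving variance at 1.
    """

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        # Trainable parameters (updated by gradient descent in a real model;
        # left fixed here to isolate the effect of the statistics).
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        # Non-trainable state: exponential moving averages of batch statistics.
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            # Training mode: normalize with the statistics of this very batch...
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and move the running averages a small step towards them.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference mode: only the moving averages are used.
            mean, var = self.moving_mean, self.moving_var
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta


# Usage: one fixed batch of 32 samples with 4 features, far from zero mean.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(32, 4))
bn = BatchNormSketch(num_features=4)
y_train = bn(x, training=True)   # per-feature mean ~0, as during fit()
y_infer = bn(x, training=False)  # still close to raw x: moving stats barely moved
```

As far as I know, Keras' `fit()` reports loss/acc computed in training mode, while validation, `evaluate()` and `predict()` run in inference mode, which would be consistent with the pattern in the symptoms above.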
Actual Questions
But on further thought, this doesn't really settle the issue, and actually raises further doubts. I am still perplexed by the following points:
- the alleged Keras BN breakage is said to affect the transfer learning use case, while mine is a classic convolutional classifier trained from scratch with the standard Glorot initialization. If it affected this case too, thousands of users should have complained, yet there is little discussion about it.
- technically: if my understanding is correct, why aren't these statistics, which are so fundamental for prediction, saved in the model, so that their latest values are available when making a prediction? It seems perfectly feasible to keep them and use them at prediction time, as with any trainable parameter.
- management-wise: if Keras' BN were really broken, how could such a dreadful bug remain unaddressed for more than a year? Is there really nobody out there using BN and needing predictions from their models? And nobody able to fix it?
- more practically: if, on the contrary, it is not a bug but just a misunderstanding of how to use it, where can I find a clear illustration of "how to correctly get a prediction in Keras from a model that uses BN"? (Demo code would be appreciated.)
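On the "technically" point: the moving statistics are ordinary layer state, so nothing in principle prevents persisting them next to the trainable parameters. As far as I know, in Keras' `BatchNormalization` they are the non-trainable weights `moving_mean` and `moving_variance`, and `model.save()` / `model.save_weights()` write them out together with `gamma` and `beta`. The NumPy sketch below only illustrates the idea: `bn_inference` is a hypothetical helper, the values are made up, and an in-memory `.npz` file stands in for a model file.

```python
import io

import numpy as np


def bn_inference(x, state, epsilon=1e-3):
    """Hypothetical helper: inference-mode batch norm from persisted state only."""
    x_hat = (x - state["moving_mean"]) / np.sqrt(state["moving_var"] + epsilon)
    return state["gamma"] * x_hat + state["beta"]


# Made-up layer state after some training: two learned parameters
# plus the two moving statistics.
state = {
    "gamma": np.array([1.5, 0.5]),
    "beta": np.array([0.1, -0.2]),
    "moving_mean": np.array([3.0, -1.0]),
    "moving_var": np.array([4.0, 0.25]),
}

# Save all four arrays together, then load them back.
buffer = io.BytesIO()
np.savez(buffer, **state)
buffer.seek(0)
with np.load(buffer) as loaded:
    restored = {name: loaded[name] for name in loaded.files}

x = np.array([[3.0, -1.0], [5.0, 0.0]])
before = bn_inference(x, state)
after = bn_inference(x, restored)  # identical: the statistics survived the round trip
```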
Obviously I would really love the last question to be the right one, but I had to include the previous ones, given that someone claims Keras BN is broken.
Note to SE OP: before *closing the question as too broad*, please consider that, since it is not really clear what the issue is (Keras BN being broken, or users being unable to use it properly), I had to offer several directions, from which anyone wishing to answer can choose.
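Regarding the last, practical question, one workaround sometimes used when the moving statistics lag behind is to re-estimate them from data before predicting. This is a technique named here for illustration, not a Keras API. The NumPy sketch uses hypothetical helpers `reestimate_bn_stats` and `bn_forward`; in Keras the resulting arrays would presumably be written back with `layer.set_weights([gamma, beta, moving_mean, moving_variance])`, assuming the default `center=True, scale=True`.

```python
import numpy as np


def reestimate_bn_stats(batches):
    """Hypothetical helper: per-feature statistics averaged over some batches.

    Averaging per-batch variances ignores the spread between batch means, so
    with many batches this slightly underestimates the overall variance.
    """
    means = np.stack([b.mean(axis=0) for b in batches])
    variances = np.stack([b.var(axis=0) for b in batches])
    return means.mean(axis=0), variances.mean(axis=0)


def bn_forward(x, mean, var, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Batch normalization with explicitly supplied statistics."""
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta


rng = np.random.default_rng(1)
batch = rng.normal(loc=2.0, scale=3.0, size=(16, 3))

# Training mode: the batch's own statistics.
train_out = bn_forward(batch, batch.mean(axis=0), batch.var(axis=0))
# Inference mode with freshly initialized moving statistics (mean 0, variance 1).
stale_out = bn_forward(batch, np.zeros(3), np.ones(3))
# Inference mode after re-estimating the statistics on the training data.
mean, var = reestimate_bn_stats([batch])
fresh_out = bn_forward(batch, mean, var)
```

In this sketch, with a single batch, `fresh_out` coincides with `train_out`, while `stale_out` does not.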
Details
- I am using Keras 2.2.4 from a Python 3.6 virtual environment (under pyenv/virtualenv).
- data are fed through a classic `ImageDataGenerator()` + `flow_from_directory()` / `flow_from_dataframe()` scheme (augmentation is turned off, though: only `rescale=1./255` is applied), but I also tried making them static. In the end, to verify the behaviour described above, I generated a single dataset with `x, y = next(valid_generator)` and used that one batch for both training and validation. While the training side converges (yes, the aim was precisely to let it overfit!), on the validation side both metrics are poor and predictions are completely wrong and erratic (almost random). In this setup, if BN is turned off, `val_loss` and `val_acc` match `loss` and `acc` exactly, as well as the metrics I can obtain from predictions computed after training has finished.
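The single-batch experiment makes the lag of the moving statistics easy to quantify. Assuming the moving mean starts at 0, the default momentum m = 0.99, and a fixed batch mean mu (in reality it drifts as the weights of the preceding layers change), after k training updates the moving mean is mu * (1 - m^k), so a fraction m^k of the gap is still left; the same factor applies to the moving variance starting from 1. The helper names are hypothetical.

```python
import math


def moving_mean_after(k, batch_mean, momentum=0.99, init=0.0):
    """Hypothetical helper: moving mean after k updates on the same batch."""
    value = init
    for _ in range(k):
        value = momentum * value + (1 - momentum) * batch_mean
    return value


def residual_fraction(k, momentum=0.99):
    """Fraction of the initial gap (init vs. batch mean) left after k updates."""
    return momentum ** k


def steps_to_reach(fraction, momentum=0.99):
    """Smallest k such that momentum**k <= fraction."""
    return math.ceil(math.log(fraction) / math.log(momentum))


# With one batch per epoch, k is simply the number of epochs.
print(moving_mean_after(100, 5.0))  # about 3.17, i.e. 5.0 * (1 - 0.99**100)
print(residual_fraction(100))       # about 0.366: over a third of the gap remains
print(steps_to_reach(0.01))         # 459 updates to get below 1%
```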
Update
While writing a minimal example, and after struggling to make the problem show up, I realized that it appears on some machines and not on others. In particular, it is evident on a host running Keras 2.3.1, while another host with Keras 2.2.4 doesn't show it. I'll post a minimal example here, along with the specific module versions, as soon as possible.