I'm attempting to train a TensorFlow model based on the popular slim implementation of mobilenet_v2 and am observing behaviour I cannot explain, related (I think) to batch normalization.
Problem Summary
Model performance in inference mode improves initially, but after a long period the model starts producing trivial inferences (all near-zero). Good performance continues when run in training mode, even on the evaluation dataset. Evaluation performance is somehow affected by the batch normalization decay/momentum rate.
More extensive implementation details are below, but I'll probably lose most of you with the wall of text, so here are some pictures to get you interested.

The curves below are from a model whose bn_decay parameter I changed during training:

- 0-370k steps: bn_decay=0.997 (default)
- 370k-670k steps: bn_decay=0.9
- 670k+ steps: bn_decay=0.5
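bn_decay is the decay of the exponential moving averages that batch normalization keeps for use in inference mode. As a rough guide to what the three settings mean, here is a small sketch; ema_summary is a hypothetical helper, and the 95% threshold is an arbitrary choice for illustration.

```python
import math


def ema_summary(decay):
    """Summarise a moving average m <- decay * m + (1 - decay) * x,
    the update rule used for batch norm moving mean/variance."""
    # Rough number of recent steps that dominate the average
    window = 1.0 / (1.0 - decay)
    # Steps for ~95% of a sudden shift in the batch statistic to be absorbed,
    # since the remaining fraction of the old value after n steps is decay ** n
    steps_to_95 = math.log(0.05) / math.log(decay)
    return window, steps_to_95


for d in (0.997, 0.9, 0.5):
    window, steps_to_95 = ema_summary(d)
    print(f"bn_decay={d}: ~{window:.0f}-step window, "
          f"~{steps_to_95:.0f} steps to absorb 95% of a shift")
```

At the default 0.997 the moving statistics average over a few hundred steps and take on the order of a thousand steps to catch up with a shift, whereas 0.5 tracks the last couple of batches almost directly.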
[Plot: loss for training (orange, training mode) and evaluation (blue, inference mode). Low is good.]
[Plot: evaluation metric of the model on the evaluation dataset in inference mode. High is good.]
I have attempted to produce a minimal example which demonstrates the issue - classification on MNIST - but have failed (i.e. classification works well and the problem I experience is not exhibited). My apologies for not being able to reduce things further.
Implementation Details
My problem is 2D pose estimation, targeting Gaussians centered at the joint locations. It is essentially the same as semantic segmentation, except that rather than using softmax_cross_entropy_with_logits(labels, logits) I use tf.losses.l2_loss(sigmoid(logits) - gaussian(label_2d_points)). (I use the term "logits" to describe the unactivated output of my learned model, though this probably isn't the best term.)
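To make the loss concrete, here is a pure-Python sketch for a single joint channel. It assumes an unnormalised Gaussian target with a peak of 1 and reads tf.losses.l2_loss as half the sum of squared errors (the tf.nn.l2_loss convention); gaussian_heatmap, sigmoid and l2_heatmap_loss are hypothetical helper names, not TensorFlow functions.

```python
import math


def gaussian_heatmap(height, width, cx, cy, sigma):
    """Unnormalised Gaussian target centred at (cx, cy), peak value 1.0.

    Hypothetical stand-in for gaussian(label_2d_points), for one joint.
    """
    return [[math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2.0 * sigma ** 2))
             for x in range(width)]
            for y in range(height)]


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def l2_heatmap_loss(logits, target):
    """Half the sum of squared errors between sigmoid(logits) and the target
    (assumed tf.nn.l2_loss convention)."""
    return 0.5 * sum((sigmoid(z) - t) ** 2
                     for z_row, t_row in zip(logits, target)
                     for z, t in zip(z_row, t_row))


target = gaussian_heatmap(64, 64, cx=32, cy=32, sigma=2.0)
near_zero_logits = [[-30.0] * 64 for _ in range(64)]  # sigmoid(-30) is ~1e-13
print(l2_heatmap_loss(near_zero_logits, target))
```

Driving the logits strongly negative, as in the near-zero inference outputs described in this post, makes the loss essentially independent of the logits: it then only reflects the squared mass of the target.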
Inference Model
After preprocessing my inputs, my logits function is a scoped call to the base mobilenet_v2 followed by a single unactivated convolutional layer to make the number of filters appropriate.
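```python
import tensorflow as tf
from slim.nets.mobilenet import mobilenet_v2


def get_logits(image, n_joints, is_training, bn_decay=0.997):
    # training_scope applies is_training and bn_decay to MobilenetV2's batch norm layers
    with mobilenet_v2.training_scope(
            is_training=is_training, bn_decay=bn_decay):
        # base_only=True returns the feature map without the classification head
        base, _ = mobilenet_v2.mobilenet(image, base_only=True)
    # Single unactivated 1x1 convolution producing one channel per joint
    logits = tf.layers.conv2d(base, n_joints, 1, 1)
    return logits
```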
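For batch normalization, is_training selects which statistics are used: the current batch's mean and variance (while updating the moving averages), or the stored moving averages. A minimal pure-Python sketch for a single channel follows; batch_norm is a hypothetical helper, not the TensorFlow implementation, the learned scale and offset are omitted, and epsilon=1e-3 is assumed to match the usual slim/tf.layers default.

```python
import math


def batch_norm(xs, moving_mean, moving_var, training, decay=0.997, eps=1e-3):
    """Normalise a 1-D batch of activations for one channel.

    training=True  -> use the batch's own mean/variance and update the moving stats
    training=False -> use the stored moving mean/variance (inference mode)
    Returns (normalised values, new moving mean, new moving variance).
    """
    if training:
        mean = sum(xs) / len(xs)
        # Population (biased) variance, for simplicity
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        # Exponential moving average update (the job of the UPDATE_OPS)
        moving_mean = decay * moving_mean + (1.0 - decay) * mean
        moving_var = decay * moving_var + (1.0 - decay) * var
    else:
        mean, var = moving_mean, moving_var
    ys = [(x - mean) / math.sqrt(var + eps) for x in xs]
    return ys, moving_mean, moving_var


batch = [4.0, 5.0, 6.0]
# Moving statistics that no longer match the batches the network sees
print(batch_norm(batch, 0.0, 1.0, training=True)[0])   # centred around 0
print(batch_norm(batch, 0.0, 1.0, training=False)[0])  # offset by the stale mean
```

If the moving statistics drift away from the statistics of the batches the network actually produces, the two modes diverge even with identical weights, which is one way a training-mode/inference-mode gap like the one described here could arise.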
Training Op
I have experimented with tf.contrib.slim.learning.create_train_op as well as a custom training op:
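```python
import tensorflow as tf


def get_train_op(optimizer, loss):
    global_step = tf.train.get_or_create_global_step()
    # Gradient step on the loss; also increments global_step
    opt_op = optimizer.minimize(loss, global_step)
    # Batch norm moving-average updates live in UPDATE_OPS and must be run explicitly
    update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    update_ops.add(opt_op)
    # Run the optimizer step and all update ops together
    return tf.group(*update_ops)
```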
I'm using tf.train.AdamOptimizer with learning rate=1e-3.
Training Loop
I'm using the tf.estimator.Estimator API for training/evaluation.
Behaviour
Training initially goes well, with an expected sharp increase in performance. This is consistent with my expectations, as the final layer is rapidly trained to interpret the high-level features output by the pretrained base model.
However, after a long period (60k steps with batch_size 8, ~8 hours on a GTX-1070) my model begins to output near-zero values (~1e-11) when run in inference mode, i.e. is_training=False. The exact same model continues to improve when run in training mode, i.e. is_training=True, even on the evaluation set. I have verified this visually.
After some experimentation I changed bn_decay (the batch normalization decay/momentum rate) from the default 0.997 to 0.9 at ~370k steps (I also tried 0.99, but that didn't make much of a difference) and observed an immediate improvement in accuracy. Visual inspection of the outputs in inference mode showed clear peaks of order ~1e-1 in the expected places, consistent with the locations of the peaks in training mode (though with much lower values). This is why the accuracy increases significantly while the loss, though more volatile, does not improve much.
These effects wore off with more training, and inference reverted to all-zero outputs.
I further dropped bn_decay to 0.5 at step ~670k. This resulted in improvements to both loss and accuracy. I'll likely have to wait until tomorrow to see the long-term effect.
About the loss and evaluation metric plots: the evaluation metric is based on the argmax of the logits, and high is good. Loss is based on the actual values, and low is good. Orange uses is_training=True on the training set, while blue uses is_training=False on the evaluation set. The loss of around 8 is consistent with all-zero outputs.
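The link between a loss of around 8 and all-zero outputs can be sanity-checked with a back-of-envelope calculation. This is a sketch under assumptions not stated in the post: each of J joints has its own channel with an unnormalised Gaussian target of width s (in output pixels), z_j are the logits for that channel, the loss is half the sum of squared errors per image (the tf.nn.l2_loss convention), the Gaussians are not clipped by the image border, and s is large enough (about a pixel or more) for the pixel sum to be approximated by an integral.

```latex
g_j(x, y) = \exp\!\left(-\frac{(x - x_j)^2 + (y - y_j)^2}{2s^2}\right)

\mathcal{L} = \frac{1}{2} \sum_{j=1}^{J} \sum_{x,y} \bigl(\operatorname{sigmoid}(z_j(x,y)) - g_j(x,y)\bigr)^2

\operatorname{sigmoid}(z_j) \approx 0
\;\Rightarrow\;
\mathcal{L}_0 \approx \frac{1}{2} \sum_{j=1}^{J} \iint \exp\!\left(-\frac{(x - x_j)^2 + (y - y_j)^2}{s^2}\right) dx\,dy
= \frac{1}{2} J \pi s^2
```

J and s are not given in the post, so this cannot confirm the value of 8 on its own; it gives the scale to compare against (multiplied by the batch size if the loss is summed over the batch).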
Other notes
- I have also experimented with turning off dropout (i.e. always running the dropout layers with is_training=False), and observed no difference.
- I have experimented with all versions of TensorFlow from 1.7 to 1.10. No difference.
- I have trained models from the pretrained checkpoint using bn_decay=0.99 from the start. Same behaviour as with the default bn_decay.
- Other experiments with a batch size of 16 result in qualitatively identical behaviour (though I can't evaluate and train simultaneously due to memory constraints, hence the quantitative analysis uses a batch size of 8).
- I have trained different models with the same loss using the tf.layers API, trained from scratch. They worked fine.
- Training from scratch (rather than using pretrained checkpoints) results in similar behaviour, though it takes longer.
Summary/my thoughts:
- I am confident this is not an overfitting/dataset problem. The model makes sensible inferences on the evaluation set when run with is_training=True, both in terms of the location and the magnitude of the peaks.
- I am confident this is not a problem with not running update ops. I haven't used slim before, but apart from the use of arg_scope it doesn't look much different from the tf.layers API, which I've used extensively. I can also inspect the moving average values and observe that they change as training progresses.
- Changing bn_decay values significantly affected the results, temporarily. I accept that a value of 0.5 is absurdly low, but I'm running out of ideas.
- I have tried swapping out slim.layers.conv2d layers for tf.layers.conv2d with momentum=0.997 (i.e. momentum consistent with the default decay value) and the behaviour was the same.
- A minimal example using pretrained weights and the Estimator framework worked for classification of MNIST without modifying the bn_decay parameter.
I've looked through issues on both the TensorFlow and models GitHub repositories but haven't found much apart from this. I'm currently experimenting with a lower learning rate and a simpler optimizer (MomentumOptimizer), but that's more because I'm running out of ideas than because I think that's where the problem lies.
Possible Explanations
- The best explanation I have is that my model parameters are rapidly cycling in a manner such that the moving statistics are unable to keep up with the batch statistics. I've never heard of such behaviour, and it doesn't explain why the model reverts to poor behaviour after more time, but it's the best explanation I have.
- There may be a bug in the moving average code, but it has worked perfectly for me in every other case, including a simple classification task. I don't want to file an issue until I can produce a simpler example.
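The first explanation can be explored with a toy model. Assuming, hypothetically, that a batch statistic drifts linearly by a fixed amount per step, a moving average with decay d settles at a constant lag of drift * d / (1 - d) behind it. The sketch simulates this; ema_lag_under_drift is a hypothetical helper, and linear drift is an idealisation, not a measurement from this model.

```python
def ema_lag_under_drift(decay, drift_per_step, steps):
    """Simulate a moving average tracking a batch statistic that drifts linearly.

    The batch statistic is x_t = drift_per_step * t (an idealised drift).
    Returns the gap between the current batch statistic and the moving
    average after `steps` updates.
    """
    moving = 0.0
    x = 0.0
    for t in range(1, steps + 1):
        x = drift_per_step * t
        # Same update rule as the batch norm moving averages
        moving = decay * moving + (1.0 - decay) * x
    return x - moving


for d in (0.997, 0.9, 0.5):
    simulated = ema_lag_under_drift(d, drift_per_step=1e-3, steps=20_000)
    predicted = 1e-3 * d / (1.0 - d)  # steady-state lag: drift * d / (1 - d)
    print(f"decay={d}: simulated lag {simulated:.6f}, predicted {predicted:.6f}")
```

For the same drift, the lag at decay 0.997 is about 332 times the per-step drift, versus 9 times at 0.9 and 1 times at 0.5. Lowering bn_decay therefore shrinks the mismatch between moving and batch statistics, which is at least consistent in direction with the temporary improvements reported above.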
Anyway, I'm running out of ideas, the debug cycle is long, and I've already spent too much time on this. Happy to provide more details or run experiments on demand. Also happy to post more code, though I'm worried that'll scare more people off.
Thanks in advance.