This is the first time I've run into this issue. I've been using this model for a while, but with less data. The problem is that training starts at 11 s/step (≈31k samples / batch size 128, i.e. 248 steps per epoch), then slows to 18 s/step, and by the next epoch to about 45 s/step. I'm using plain Keras `model.fit` and not doing any custom-loop shenanigans.
Can someone explain this slowdown? The training run was never interrupted. I'm on TF 2.3.
Epoch 1/1200
248/248 [==============================] - 2727s 11s/step - loss: 2.3481 - acc: 0.3818 - top3_acc: 0.5751 - recall: 0.2228 - precision: 0.6195 - f1: 0.3239 - val_loss: 0.9020 - val_acc: 0.8085 - val_top3_acc: 0.8956 - val_recall: 0.5677 - val_precision: 0.9793 - val_f1: 0.7179
Epoch 2/1200
248/248 [==============================] - 2712s 11s/step - loss: 1.0319 - acc: 0.7203 - top3_acc: 0.8615 - recall: 0.5489 - precision: 0.9245 - f1: 0.6865 - val_loss: 0.5547 - val_acc: 0.8708 - val_top3_acc: 0.9371 - val_recall: 0.7491 - val_precision: 0.9661 - val_f1: 0.8435
Epoch 3/1200
248/248 [==============================] - 4426s 18s/step - loss: 0.7094 - acc: 0.8093 - top3_acc: 0.9178 - recall: 0.6830 - precision: 0.9446 - f1: 0.7920 - val_loss: 0.4399 - val_acc: 0.8881 - val_top3_acc: 0.9567 - val_recall: 0.8140 - val_precision: 0.9606 - val_f1: 0.8808
Epoch 4/1200
18/248 [=>............................] - ETA: 3:14:16 - loss: 0.6452 - acc: 0.8338 - top3_acc: 0.9223 - recall: 0.7257 - precision: 0.9536 - f1: 0.8240
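For reference, the per-step times follow directly from the wall-clock numbers in the log above:

```python
# Derive seconds/step from the Keras progress-bar output above.
steps_per_epoch = 248                      # as shown in the log

epoch1 = 2727 / steps_per_epoch            # epoch 1 took 2727 s total
epoch3 = 4426 / steps_per_epoch            # epoch 3 took 4426 s total

# Epoch 4 was cut off at step 18/248 with ETA 3:14:16 for the remaining steps.
eta_seconds = 3 * 3600 + 14 * 60 + 16      # 11656 s
epoch4 = eta_seconds / (steps_per_epoch - 18)

print(round(epoch1, 1), round(epoch3, 1), round(epoch4, 1))  # → 11.0 17.8 50.7
```

So the slowdown is roughly 11 → 18 → 50 s/step over consecutive epochs, i.e. each epoch is not just slower but increasingly slower.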
Edit: I just ran this on a very small sample of the data (20 items per category) and the step time does not increase. proof
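For anyone wanting to reproduce the measurement: a minimal timing callback like the one below (a sketch; the class name is mine) logs the mean seconds/step per epoch, which is how the small-sample comparison can be made precise.

```python
import time
import tensorflow as tf  # written against the TF 2.3-era Keras callback API


class StepTimeLogger(tf.keras.callbacks.Callback):
    """Log mean seconds/step for every epoch, to pinpoint where the slowdown starts."""

    def on_epoch_begin(self, epoch, logs=None):
        self._batch_times = []

    def on_train_batch_begin(self, batch, logs=None):
        self._t0 = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        self._batch_times.append(time.perf_counter() - self._t0)

    def on_epoch_end(self, epoch, logs=None):
        mean = sum(self._batch_times) / max(len(self._batch_times), 1)
        print(f"epoch {epoch + 1}: {mean:.2f} s/step over {len(self._batch_times)} steps")
```

Pass an instance via `model.fit(..., callbacks=[StepTimeLogger()])` to get one timing line per epoch alongside the usual progress bar.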
Edit 2: Model summary
Model: "functional_3"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_token (InputLayer) [(None, 300)] 0
__________________________________________________________________________________________________
masked_token (InputLayer) multiple 0 input_token[0][0]
__________________________________________________________________________________________________
tf_distil_bert_model (TFDistilB ((None, 300, 768),) 66362880 masked_token[1][0]
__________________________________________________________________________________________________
tf_op_layer_strided_slice (Tens multiple 0 tf_distil_bert_model[1][0]
__________________________________________________________________________________________________
efficientnetb5_input (InputLaye [(None, 456, 456, 3) 0
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 768) 3072 tf_op_layer_strided_slice[1][0]
__________________________________________________________________________________________________
efficientnetb5 (Functional) (None, 15, 15, 2048) 28513527 efficientnetb5_input[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 256) 196864 batch_normalization[1][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 2048) 0 efficientnetb5[1][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 140) 35980 dense[1][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 140) 286860 global_average_pooling2d[1][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 280) 0 dense_1[1][0]
dense_3[1][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 100) 28100 concatenate[0][0]
__________________________________________________________________________________________________
dropout_20 (Dropout) (None, 100) 0 dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 20) 2020 dropout_20[0][0]
==================================================================================================
Total params: 95,429,303
Trainable params: 30,120
Non-trainable params: 95,399,183
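As a sanity check, the parameter counts in the summary add up: the 30,120 trainable params are exactly the last two Dense layers (dense_4 and dense_5), so both backbones (DistilBERT and EfficientNetB5) are frozen.

```python
# Dense layer params = inputs * units + units (bias)
dense_4 = 280 * 100 + 100   # 28,100
dense_5 = 100 * 20 + 20     # 2,020
print(dense_4 + dense_5)    # → 30120, matching "Trainable params: 30,120"

# All layers sum to the reported total of 95,429,303.
per_layer = [66362880, 28513527, 3072, 196864, 35980, 286860, 28100, 2020]
print(sum(per_layer))       # → 95429303
```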