I'm currently building a CNN for binary classification of CT images (1024 x 1024 px, grayscale). I know my datasets are small: 100 samples for the training/validation set (50 'class 0' and 50 'class 1') and 27 samples (6 and 21) for the test set. After many experiments, this network configuration seemed to be the best one:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Dropout, Flatten, Dense

model = Sequential()
model.add(Conv2D(32, (3,3), activation='relu', input_shape=(1024, 1024, 1)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.7))
model.add(Conv2D(32, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.7))
model.add(Conv2D(32, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dropout(0.7))
model.add(Dense(16, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(8, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(4, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(2, activation='softmax'))
model.summary()
Below, you can see the other settings:
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from numpy.random import seed
from tensorflow import set_random_seed

# compile model
model.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
# callbacks
earlyStopping = EarlyStopping(min_delta=0.01, monitor='accuracy', patience=40, mode='max', restore_best_weights=True)
reduce_lr = ReduceLROnPlateau()
# Fitting
seed(100)
set_random_seed(100)
model.fit(X_train, y_train, batch_size=10, validation_split=0.1, epochs=50, shuffle=True, callbacks=[reduce_lr, earlyStopping])
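For completeness, this is roughly how X_train and y_train are prepared before fitting (a sketch: load_slices and labels are placeholders for my actual CT reader, and the [0, 1] scaling is illustrative):
import numpy as np
from keras.utils import to_categorical

X_train = load_slices(...)  # placeholder for my own CT loader
# add the channel axis expected by input_shape=(1024, 1024, 1) and scale to [0, 1]
X_train = X_train.reshape(-1, 1024, 1024, 1).astype('float32') / 255.0
# integer labels 0/1 -> one-hot vectors, to match the 2-unit softmax output
y_train = to_categorical(labels, num_classes=2)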
As you can see below, the performance seems good: both the training loss and the validation loss decrease over the epochs, while the accuracies keep increasing.
Epoch 1/50
90/90 [==============================] - 43s 478ms/step - loss: 0.8159 - accuracy: 0.4889 - val_loss: 0.6816 - val_accuracy: 0.6000
Epoch 2/50
90/90 [==============================] - 42s 470ms/step - loss: 0.7045 - accuracy: 0.6000 - val_loss: 0.5591 - val_accuracy: 1.0000
Epoch 3/50
90/90 [==============================] - 42s 467ms/step - loss: 0.6565 - accuracy: 0.6667 - val_loss: 0.4171 - val_accuracy: 1.0000
Epoch 4/50
90/90 [==============================] - 43s 474ms/step - loss: 0.6253 - accuracy: 0.6111 - val_loss: 0.0789 - val_accuracy: 1.0000
Epoch 5/50
90/90 [==============================] - 43s 474ms/step - loss: 0.6086 - accuracy: 0.6889 - val_loss: 0.0068 - val_accuracy: 1.0000
Epoch 6/50
90/90 [==============================] - 43s 474ms/step - loss: 0.6311 - accuracy: 0.6444 - val_loss: 0.0067 - val_accuracy: 1.0000
Epoch 7/50
90/90 [==============================] - 42s 468ms/step - loss: 0.5163 - accuracy: 0.7444 - val_loss: 5.7406e-04 - val_accuracy: 1.0000
Epoch 8/50
90/90 [==============================] - 43s 473ms/step - loss: 0.5752 - accuracy: 0.7222 - val_loss: 7.9747e-05 - val_accuracy: 1.0000
Epoch 9/50
90/90 [==============================] - 43s 475ms/step - loss: 0.5440 - accuracy: 0.7444 - val_loss: 1.9908e-06 - val_accuracy: 1.0000
Epoch 10/50
90/90 [==============================] - 42s 468ms/step - loss: 0.5481 - accuracy: 0.7444 - val_loss: 1.5497e-07 - val_accuracy: 1.0000
Epoch 11/50
90/90 [==============================] - 42s 471ms/step - loss: 0.5032 - accuracy: 0.7778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 12/50
90/90 [==============================] - 43s 475ms/step - loss: 0.5036 - accuracy: 0.7889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 13/50
90/90 [==============================] - 42s 471ms/step - loss: 0.4714 - accuracy: 0.8222 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 14/50
90/90 [==============================] - 42s 470ms/step - loss: 0.4353 - accuracy: 0.8667 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 15/50
90/90 [==============================] - 42s 467ms/step - loss: 0.4243 - accuracy: 0.8556 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 16/50
90/90 [==============================] - 42s 471ms/step - loss: 0.4046 - accuracy: 0.8444 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 17/50
90/90 [==============================] - 43s 479ms/step - loss: 0.4068 - accuracy: 0.8556 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 18/50
90/90 [==============================] - 43s 476ms/step - loss: 0.4336 - accuracy: 0.8111 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 19/50
90/90 [==============================] - 42s 470ms/step - loss: 0.3909 - accuracy: 0.8778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 20/50
90/90 [==============================] - 42s 470ms/step - loss: 0.3343 - accuracy: 0.9111 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 21/50
90/90 [==============================] - 43s 475ms/step - loss: 0.4284 - accuracy: 0.8444 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 22/50
90/90 [==============================] - 42s 469ms/step - loss: 0.4517 - accuracy: 0.8111 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 23/50
90/90 [==============================] - 42s 465ms/step - loss: 0.3613 - accuracy: 0.9000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 24/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3489 - accuracy: 0.9111 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 25/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3843 - accuracy: 0.8778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 26/50
90/90 [==============================] - 42s 471ms/step - loss: 0.3456 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 27/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3643 - accuracy: 0.8778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 28/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3856 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 29/50
90/90 [==============================] - 42s 471ms/step - loss: 0.3988 - accuracy: 0.8667 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 30/50
90/90 [==============================] - 43s 475ms/step - loss: 0.3948 - accuracy: 0.8778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 31/50
90/90 [==============================] - 42s 470ms/step - loss: 0.4920 - accuracy: 0.7333 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 32/50
90/90 [==============================] - 42s 472ms/step - loss: 0.3573 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 33/50
90/90 [==============================] - 42s 465ms/step - loss: 0.4097 - accuracy: 0.8333 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 34/50
90/90 [==============================] - 42s 469ms/step - loss: 0.3378 - accuracy: 0.9111 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 35/50
90/90 [==============================] - 42s 472ms/step - loss: 0.3782 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 36/50
90/90 [==============================] - 43s 475ms/step - loss: 0.3891 - accuracy: 0.8333 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 37/50
90/90 [==============================] - 42s 470ms/step - loss: 0.3851 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 38/50
90/90 [==============================] - 42s 472ms/step - loss: 0.3569 - accuracy: 0.9000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 39/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3795 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 40/50
90/90 [==============================] - 42s 466ms/step - loss: 0.4399 - accuracy: 0.7667 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 41/50
90/90 [==============================] - 42s 466ms/step - loss: 0.4298 - accuracy: 0.8556 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 42/50
90/90 [==============================] - 43s 473ms/step - loss: 0.3553 - accuracy: 0.8778 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 43/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3967 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 44/50
90/90 [==============================] - 43s 475ms/step - loss: 0.3783 - accuracy: 0.8556 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 45/50
90/90 [==============================] - 43s 474ms/step - loss: 0.3379 - accuracy: 0.9444 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 46/50
90/90 [==============================] - 43s 473ms/step - loss: 0.3679 - accuracy: 0.9000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 47/50
90/90 [==============================] - 43s 475ms/step - loss: 0.3074 - accuracy: 0.9333 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 48/50
90/90 [==============================] - 43s 475ms/step - loss: 0.3210 - accuracy: 0.9333 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 49/50
90/90 [==============================] - 43s 476ms/step - loss: 0.4226 - accuracy: 0.8556 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
Epoch 50/50
90/90 [==============================] - 43s 476ms/step - loss: 0.3811 - accuracy: 0.8889 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
However, when I tried to predict the classes of both datasets (training/validation set and test set), I got an unexpected and surprising result: the CNN predicts only one class.
model.predict(X_train)
array([[0., 1.],
       [0., 1.],
       [0., 1.],
       ...
       [0., 1.],
       [0., 1.]], dtype=float32)
(all 100 rows are identical: [0., 1.])
model.predict(X_test)
array([[0., 1.],
       [0., 1.],
       [0., 1.],
       ...
       [0., 1.],
       [0., 1.]], dtype=float32)
(all 27 rows are identical: [0., 1.])
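For clarity, this is how I turn the softmax rows into hard labels (np.argmax over axis 1 picks the winning class for each sample); every sample ends up in class 1:
import numpy as np

pred_train = np.argmax(model.predict(X_train), axis=1)  # hard labels from softmax rows
pred_test = np.argmax(model.predict(X_test), axis=1)
print(np.unique(pred_train), np.unique(pred_test))  # [1] [1]: only class 1 is ever predicted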
Here, I found out that this issue can be caused by the BatchNormalization() layers. The solution proposed in that post was to evaluate the samples with these few lines:
import numpy as np

# Prediction on the training set
learning_phase = 1                # 1 = training mode (batch statistics), 0 = inference mode
sample_weights_TR = np.ones(100)  # uniform weights, one per training sample
ins_TR = [X_train, y_train, sample_weights_TR, learning_phase]
model.test_function(ins_TR)       # returns [loss, accuracy]
[0.39475524, 0.8727273]  # average accuracy = 0.873

# Prediction on the test set
sample_weights_TS = np.ones(27)   # one weight per test sample
ins_TS = [X_test, y_test, sample_weights_TS, learning_phase]
model.test_function(ins_TS)       # returns [loss, accuracy]
[0.7853608, 0.79562044]  # average accuracy = 0.796
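For reference, the same learning-phase trick can also be written with the Keras backend function API, which returns the raw per-sample outputs instead of aggregated metrics (a sketch, assuming Keras 2.x on the TF1 backend; K.learning_phase() plays the role of the learning_phase flag above):
from keras import backend as K

# build a function that runs the model with an explicit learning-phase flag
get_output = K.function([model.input, K.learning_phase()], [model.output])

# learning phase 1 = training mode: BatchNormalization uses batch statistics
probs_train = get_output([X_train, 1])[0]  # softmax outputs, shape (100, 2)
probs_test = get_output([X_test, 1])[0]    # shape (27, 2)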
However, if I run model.test_function(ins_TR) or model.test_function(ins_TS) again and again, the results are different every time; the accuracies keep going down!
model.test_function(ins_TS)
[0.7853608, 0.79562044]
model.test_function(ins_TS)
[0.88033366, 0.75]
model.test_function(ins_TS)
[0.9703789, 0.7068063]
model.test_function(ins_TS)
[0.86600196, 0.69266057]
model.test_function(ins_TS)
[0.79449946, 0.67755103]
model.test_function(ins_TS)
[0.77942985, 0.6691176]
model.test_function(ins_TS)
[0.83834535, 0.6588629]
Therefore, my questions are:
- Do you have any suggestions to improve the CNN?
- How can I obtain correct predictions in the presence of BatchNormalization() layers?
- With the solution proposed here, how can I get the prediction for each individual sample (I would like to assess the CNN's sensitivity and specificity)?
Thanks in advance, Mattia