I am playing with a naive U-net that I'm deploying on MNIST as a toy dataset.
I am seeing a strange behaviour in the way the from_logits argument works in tf.keras.losses.BinaryCrossentropy.
From what I understand, if the last layer of a network uses activation='sigmoid', then tf.keras.losses.BinaryCrossentropy must be given from_logits=False. If instead activation=None, you need from_logits=True. Either choice should work in practice, although from_logits=True is said to be more numerically stable (e.g., Why does sigmoid & crossentropy of Keras/tensorflow have low precision?). This is not what I see in the following example.
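As a sanity check of that claim, here is a standalone snippet (my own, not part of the model code below) showing that the two configurations agree once the sigmoid is applied by hand:

import tensorflow as tf

# BCE on raw logits with from_logits=True should match
# BCE on sigmoid(logits) with from_logits=False, up to numerical precision.
logits = tf.constant([[-2.0, 0.5, 3.0]])
y_true = tf.constant([[0.0, 1.0, 1.0]])

bce_from_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True)
bce_from_probs = tf.keras.losses.BinaryCrossentropy(from_logits=False)

print(bce_from_logits(y_true, logits).numpy())            # loss computed from logits
print(bce_from_probs(y_true, tf.sigmoid(logits)).numpy()) # same value, computed from probabilities

Both prints give essentially the same number, which is why I expect the two models below to behave identically.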
So, my unet goes as follows (the full code is at the end of this post):
def unet(input, init_depth, activation):
    # ... layers that define the encoder/decoder (full definition at the end of the post)
    # the last layer is a 1x1 convolution
    output = tf.keras.layers.Conv2D(1, (1,1), activation=activation)(previous_layer)  # shape = (28,28,1)
    return tf.keras.Model(input, output)
Now I define two models, one with the activation in the last layer:
input = Layers.Input((28,28,1))
model_withProbs = unet(input,4,activation='sigmoid')
model_withProbs.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
optimizer=tf.keras.optimizers.Adam()) #from_logits=False since the sigmoid is already present
and one without:
model_withLogits = unet(input,4,activation=None)
model_withLogits.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam()) #from_logits=True since there is no activation
If I'm right, the two models should behave exactly the same. Instead, the predictions of model_withLogits have pixel values up to 2500 or so (which is wrong), while model_withProbs gives values between 0 and 1 (which is right). You can check out the figures I get here.
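For reference, the predictions behind those figures are obtained roughly like this (plotting code omitted; the channel axis is added by hand, and no sigmoid is applied anywhere outside the model):

pred_probs = model_withProbs.predict(x_test[..., None])   # values stay between 0 and 1
pred_logits = model_withLogits.predict(x_test[..., None])  # values go up to ~2500
print(pred_probs.max(), pred_logits.max())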
I thought about the stability issue (from_logits=True is supposed to be more stable), but this problem appears even before training (see here). Moreover, the problem shows up exactly when I pass from_logits=True (that is, for model_withLogits), so I don't think stability is relevant.
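(For context, the precision loss the linked question is about can be reproduced with a single extreme logit; this snippet is mine and only illustrates what "more stable" means:)

big_logit = tf.constant([[20.0]])
y = tf.constant([[0.0]])

# from_logits=True works on the logit directly and gives roughly 20
print(tf.keras.losses.BinaryCrossentropy(from_logits=True)(y, big_logit).numpy())

# from_logits=False sees sigmoid(20), which saturates to 1.0 in float32, so the
# probability gets clipped internally and the loss comes out noticeably smaller
print(tf.keras.losses.BinaryCrossentropy(from_logits=False)(y, tf.sigmoid(big_logit)).numpy())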
Does anybody have a clue why this is happening? Am I missing something fundamental here?
Post Scriptum: Code
Re-purposing MNIST for segmentation.
I load MNIST:
import numpy as np
import tensorflow as tf

(x_train, labels_train), (x_test, labels_test) = tf.keras.datasets.mnist.load_data()
I am re-purposing MNIST for a segmentation task by setting to one all the non-zero values of x_train:
x_train = x_train/255 #normalisation
x_test = x_test/255
Y_train = np.zeros(x_train.shape) #create segmentation map
Y_train[x_train>0] = 1 #Y_train is zero everywhere but where the digit is drawn
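The training calls are not shown above; for completeness, they look something like this (the batch size and number of epochs are arbitrary and don't change the behaviour described earlier):

X = x_train[..., None]  # add the channel axis: (60000, 28, 28, 1)
Y = Y_train[..., None]  # binary segmentation targets
model_withProbs.fit(X, Y, batch_size=128, epochs=5)
model_withLogits.fit(X, Y, batch_size=128, epochs=5)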
Full unet network:
from tensorflow.keras import layers as Layers

def unet(input, init_depth, activation):
    # encoder
    conv1 = Layers.Conv2D(init_depth, (2,2), activation='relu', padding='same')(input)
    pool1 = Layers.MaxPool2D((2,2))(conv1)
    drop1 = Layers.Dropout(0.2)(pool1)
    conv2 = Layers.Conv2D(init_depth*2, (2,2), activation='relu', padding='same')(drop1)
    pool2 = Layers.MaxPool2D((2,2))(conv2)
    drop2 = Layers.Dropout(0.2)(pool2)
    conv3 = Layers.Conv2D(init_depth*4, (2,2), activation='relu', padding='same')(drop2)
    #pool3 = Layers.MaxPool2D((2,2))(conv3)
    #drop3 = Layers.Dropout(0.2)(conv3)

    # upsampling / decoder
    up1 = Layers.Conv2DTranspose(init_depth*2, (2,2), strides=(2,2))(conv3)
    up1 = Layers.concatenate([conv2, up1])
    conv4 = Layers.Conv2D(init_depth*2, (2,2), padding='same')(up1)
    up2 = Layers.Conv2DTranspose(init_depth, (2,2), strides=(2,2), padding='same')(conv4)
    up2 = Layers.concatenate([conv1, up2])
    conv5 = Layers.Conv2D(init_depth, (2,2), padding='same')(up2)

    # last layer: 1x1 convolution, activation passed in as an argument
    last = Layers.Conv2D(1, (1,1), activation=activation)(conv5)

    return tf.keras.Model(inputs=input, outputs=last)