I followed TensorFlow's tutorial here to create a custom image segmentation model.
My model looks like the following, and is trying to create masks like these:
However, I am having overfitting problems, so I would like to introduce some Dropout layers or regularization penalties. I am a little confused about where in the model to place them, since the tutorial doesn't really add any. Is this just something I need to experiment with? I am fairly sure the model is overfitting: over 50 epochs the training accuracy increases and the training loss decreases, but the validation accuracy jumps all over the place and the validation loss doesn't go down consistently.
Here is a code snippet, in case it's needed:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Conv2DTranspose, Cropping2D, Input
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow_examples.models.pix2pix import pix2pix

# Pretrained encoder: use intermediate activations as skip connections
base_model = MobileNetV2(input_shape=(img_height, img_width, 3), include_top=False)
layer_names = [
    'block_1_expand_relu',   # 150x200
    'block_3_expand_relu',   # 75x100
    'block_5_expand_relu',   # 38x50
    'block_7_expand_relu',   # 19x25
    'block_13_project'       # 10x13
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

down_stack = Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False

# Decoder: upsample back toward the input resolution
up_stack = [
    pix2pix.upsample(512, 3),
    pix2pix.upsample(256, 3),
    pix2pix.upsample(128, 3),
    pix2pix.upsample(64, 3)
]

inputs = Input(shape=(img_height, img_width, 3))
skips = down_stack(inputs)
x = skips[-1]
skips = reversed(skips[:-1])

for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = keras.layers.Concatenate()
    if x.shape[1] != skip.shape[1] or x.shape[2] != skip.shape[2]:
        # Crop the upsampled tensor so its spatial size matches the skip connection
        x = Cropping2D(cropping=((x.shape[1] - skip.shape[1], 0),
                                 (x.shape[2] - skip.shape[2], 0)))(x)
    x = concat([x, skip])

# Final upsample to per-pixel logits for the two classes
last = Conv2DTranspose(filters=2, kernel_size=3, strides=2, padding='same')
x = last(x)

model = Model(inputs=inputs, outputs=x)
model.compile(optimizer=Adam(),
              loss=SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
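For example, would adding a SpatialDropout2D after each upsampling step be a reasonable placement? A sketch of what a single decoder block might look like (the `up_block` helper and the 0.3 rate are just my guesses, not from the tutorial):

```python
import tensorflow as tf
from tensorflow.keras import layers

def up_block(filters, dropout_rate=0.3):
    # Hypothetical decoder step: upsample, normalize, activate, then drop
    # whole feature maps (SpatialDropout2D) rather than individual pixels
    return tf.keras.Sequential([
        layers.Conv2DTranspose(filters, 3, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.SpatialDropout2D(dropout_rate),
    ])

block = up_block(64)
x = tf.random.normal((1, 16, 16, 128))
y = block(x, training=True)  # dropout is only active when training=True
print(y.shape)  # (1, 32, 32, 64)
```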
Thanks in advance for your help.
EDIT: I do have class weights as well (like in the tutorial), since most of the image is the background class. Even so, the model still overfits.
from typing import Any, Tuple

import tensorflow as tf

def add_sample_weights(image: Any, label: Any) -> Tuple[Any, Any, Any]:
    # Weight the rarer foreground class more heavily than the background
    class_weights = tf.constant([1.0, 2.0])
    class_weights = class_weights / tf.reduce_sum(class_weights)
    sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))
    return image, label, sample_weights
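To show what the weighting actually produces, here is a toy run on a made-up 2x2 binary mask (this tiny dataset is just for illustration, not my real data):

```python
import tensorflow as tf

def add_sample_weights(image, label):
    class_weights = tf.constant([1.0, 2.0])
    class_weights = class_weights / tf.reduce_sum(class_weights)
    return image, label, tf.gather(class_weights, indices=tf.cast(label, tf.int32))

# One dummy 2x2 RGB "image" with a binary mask;
# class 0 ends up with weight 1/3 and class 1 with 2/3
images = tf.zeros((1, 2, 2, 3))
labels = tf.constant([[[0, 1], [1, 0]]])
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(add_sample_weights)
image, label, weights = next(iter(ds))
print(weights.numpy())  # [[0.333... 0.666...] [0.666... 0.333...]]
```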