
I followed TensorFlow's image segmentation tutorial here to create a custom image segmentation model.

My model looks like the following and is trying to create masks like these:

[model architecture diagram]

[example: normal image and its masked counterpart]

However, the model is over-fitting, so I would like to introduce some Dropout layers or regularization penalties, but I am not sure where in the model to place them, since the tutorial doesn't add any. Is this just something I need to experiment with? I am fairly sure the model is over-fitting: over 50 epochs the training accuracy increases and the training loss decreases, while the test/validation accuracy jumps all over the place and the validation loss doesn't decrease consistently.
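On placement: one common option (not something the tutorial does, just a sketch) is a SpatialDropout2D after each upsampling block in the decoder, leaving the frozen encoder untouched. SpatialDropout2D drops whole feature maps rather than individual activations, which tends to regularize convolutional layers better than plain Dropout. A minimal shape check of one such decoder step (the `up_block` helper is hypothetical, not from the tutorial):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical decoder step: upsample, apply spatial dropout, then
# concatenate with the skip connection. A rate of 0.2-0.5 is the usual
# range to tune.
def up_block(x, skip, filters, rate=0.3):
    x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
    x = layers.SpatialDropout2D(rate)(x)  # drops entire feature maps
    return layers.Concatenate()([x, skip])

# Tiny shape check on dummy tensors
inp = layers.Input(shape=(8, 8, 16))
skip = layers.Input(shape=(16, 16, 8))
out = up_block(inp, skip, filters=8)
model = tf.keras.Model([inp, skip], out)
print(model.output_shape)  # (None, 16, 16, 16)
```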

Here is a code snippet, in case it's needed:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Conv2DTranspose, Cropping2D
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow_examples.models.pix2pix import pix2pix

base_model = MobileNetV2(input_shape=(img_height, img_width, 3), include_top=False)

layer_names = [
    'block_1_expand_relu',  # 150x200
    'block_3_expand_relu',  # 75x100
    'block_5_expand_relu',  # 38x50
    'block_7_expand_relu',  # 19x25
    'block_13_project'  # 10x13
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

down_stack = Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False

up_stack = [
    pix2pix.upsample(512, 3),
    pix2pix.upsample(256, 3),
    pix2pix.upsample(128, 3),
    pix2pix.upsample(64, 3)
]

inputs = Input(shape=(img_height, img_width, 3))

skips = down_stack(inputs)
x = skips[-1]
skips = reversed(skips[:-1])

for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = keras.layers.Concatenate()
    if x.shape[1] != skip.shape[1] or x.shape[2] != skip.shape[2]:
        x = Cropping2D(cropping=((x.shape[1] - skip.shape[1], 0),
                                 (x.shape[2] - skip.shape[2], 0)))(x)
    x = concat([x, skip])

last = Conv2DTranspose(filters=2, kernel_size=3, strides=2, padding='same')
x = last(x)

model = Model(inputs=inputs, outputs=x)
model.compile(optimizer=Adam(), loss=SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
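Incidentally, the Cropping2D in the loop above is only needed because odd spatial sizes round up on the way down (stride-2 convolutions with 'same' padding) but are exactly doubled on the way up. A quick sanity check in plain Python (assuming a 300x400 input, which matches the commented encoder shapes):

```python
import math

h, w = 300, 400  # assumed input size matching the commented encoder shapes

# Encoder: each stride-2 block halves the size, rounding up ('same' padding).
enc = [(h, w)]
for _ in range(5):
    h, w = math.ceil(h / 2), math.ceil(w / 2)
    enc.append((h, w))
print(enc[1:])  # [(150, 200), (75, 100), (38, 50), (19, 25), (10, 13)]

# Decoder: each upsample exactly doubles, so odd skip sizes force a crop.
crops = []
x = enc[-1]
for skip in reversed(enc[1:-1]):
    x = (x[0] * 2, x[1] * 2)
    crops.append((x[0] - skip[0], x[1] - skip[1]))
    x = skip
print(crops)  # [(1, 1), (0, 0), (1, 0), (0, 0)]
```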

Thanks in advance for your help.

EDIT: I am also applying per-pixel sample weights (like in the tutorial), since most of each image is the background class. Even so, the model still over-fits.

from typing import Any, Tuple

def add_sample_weights(image: Any, label: Any) -> Tuple[Any, Any, Any]:
    class_weights = tf.constant([1.0, 2.0])
    class_weights = class_weights / tf.reduce_sum(class_weights)
    sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))
    return image, label, sample_weights
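The weighting itself seems to check out when I run it on a toy batch; this is just the function above mapped over a hypothetical one-element dataset, the way I feed it to model.fit:

```python
import tensorflow as tf

def add_sample_weights(image, label):
    class_weights = tf.constant([1.0, 2.0])
    class_weights = class_weights / tf.reduce_sum(class_weights)  # -> [1/3, 2/3]
    sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))
    return image, label, sample_weights

# Hypothetical tiny dataset (one 2x2 "image" with a binary mask) just to
# show the wiring; the real pipeline maps this before model.fit so Keras
# picks up the per-pixel weights as the third element.
ds = tf.data.Dataset.from_tensors(
    (tf.zeros((2, 2, 3)), tf.constant([[0, 1], [1, 0]]))
).map(add_sample_weights)

image, label, w = next(iter(ds))
print(w.numpy())  # background pixels weighted 1/3, foreground 2/3
```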