Let's first summarize the whole picture. The model you're trying to build and run is essentially two auto-encoders stacked back to back, designed to solve two tasks simultaneously. Given an input image, the model produces two outputs: a road map and a centerline map. But first, we need to train this model on the given dataset, where each image comes with its corresponding road and centerline segmentation masks. In simple terms, we can frame this problem as semantic segmentation with 1 input and 2 outputs.
Building the training data loader with the tf.data API for such a task is quite straightforward. There are different possible approaches, but the overall setup stays the same. About the error you faced when concatenating layers: that is expected to happen, but according to the figure in the paper, I don't think you need concatenation at all. You can simply pass the feature maps of the first network to the next. Let's build this project step by step. I'm using TF 2.11, tested on Kaggle with a P100 GPU.
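One note before the code: the snippets below share a few globals (img_size, batch_size, etc.). Here is the configuration I'm assuming; the exact values, including the callback list, are my own choices for this demo and not fixed by the paper or the question, so adjust them to your setup.

import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'  # make segmentation_models use tf.keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, backend
import segmentation_models as sm  # losses/metrics used later in compile()

# Demo configuration -- my assumptions, tune for your dataset/hardware.
img_size   = 256                  # images are resized to img_size x img_size
channel    = 3                    # RGB input
batch_size = 16
epoch      = 100
AUTOTUNE   = tf.data.AUTOTUNE

# Hypothetical callback list referenced in model.fit() below.
my_callbacks = [keras.callbacks.ModelCheckpoint('casnet.h5', save_best_only=True)]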
Model
Some common blocks of layers.
# Conv -> BatchNorm -> Activation, returned as a closure so it can be
# applied like a layer: ConvBlock(64, (3, 3), ...)(x)
def ConvBlock(filters, kernel, kernel_initializer, activation, name=None):
if name is None:
name = "ConvBlock" + str(backend.get_uid("ConvBlock"))
def apply(input):
c1 = layers.Conv2D(
filters=filters,
kernel_size=kernel,
padding='same',
kernel_initializer=kernel_initializer,
name=name+'_conv'
)(input)
c1 = layers.BatchNormalization(name=name+'_batch')(c1)
c1 = layers.Activation(activation,name=name+'_active')(c1)
return c1
return apply
# Same pattern with Conv2DTranspose; used in the decoders, while the
# actual spatial upsampling is done by a separate UpSampling2D layer.
def DownConvBlock(filters, kernel, kernel_initializer, activation, name=None):
if name is None:
name = "DownConvBlock" + str(backend.get_uid("DownConvBlock"))
def apply(input):
d1 = layers.Conv2DTranspose(
filters=filters,
kernel_size=kernel,
padding='same',
kernel_initializer=kernel_initializer,
name=name+'_conv'
)(input)
d1 = layers.BatchNormalization(name=name+'_batch')(d1)
d1 = layers.Activation(activation,name=name+'_active')(d1)
return d1
return apply
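As a quick illustration of the closure pattern (nothing here is from the paper, just a shape check):

x_in = keras.Input(shape=(img_size, img_size, channel))
x = ConvBlock(64, (3, 3), 'he_normal', 'relu')(x_in)
print(x.shape)  # (None, 256, 256, 64): spatial size kept, channels changed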
Sub-model for the road mask detection task.
def network_mask(input, activation, kernel_initializer, kernel_size):
# Network 1
# ENCODER
x = input
for fmap in [64, 128, 256, 512]:
x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)
# DECODER
for fmap in [512, 256, 128, 64]:
x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = layers.Conv2D(
filters=1,
kernel_size=(1,1),
kernel_initializer=kernel_initializer,
activation=None,
)(x)
return x
Sub-model for the centerline detection task.
def network_centerline(input, activation, kernel_initializer, kernel_size):
# Network 2
# ENCODER
x = input
for fmap in [64, 128, 256]:
x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = ConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = layers.MaxPool2D(pool_size=(2,2), strides=None, padding='same')(x)
# DECODER
for fmap in [256, 128, 64]:
x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = DownConvBlock(fmap, kernel_size, kernel_initializer, activation)(x)
x = layers.Conv2DTranspose(
filters=1,
kernel_size=(1,1),
kernel_initializer=kernel_initializer,
activation=None,
)(x)
return x
Full cascaded network, i.e. CasNet.
def CasNet(activation, kernel_initializer, kernel_size):
input = keras.Input(shape=(img_size, img_size, channel), name='images')
mask_feat = network_mask(input, activation, kernel_initializer, kernel_size)
centerline_feat = network_centerline(
mask_feat, activation, kernel_initializer, kernel_size
)
mask_op = keras.layers.Activation(
'sigmoid', name='mask', dtype=tf.float32
)(mask_feat)
centerline_op = keras.layers.Activation(
'sigmoid', name='centerline', dtype=tf.float32
)(centerline_feat)
model = keras.Model(
inputs={
'images': input
},
outputs={
'mask': mask_op,
'centerline': centerline_op
},
name='CasNet'
)
return model
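For a sanity check, we can build the model and verify that both heads produce full-resolution single-channel maps. The activation, initializer, and kernel size below are my assumptions, not values prescribed by the question:

model = CasNet(
    activation='relu',              # assumed; the paper may use something else
    kernel_initializer='he_normal', # assumed
    kernel_size=(3, 3),             # assumed
)
model.summary()
# Expect both 'mask' and 'centerline' outputs of shape (None, img_size, img_size, 1).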
Data Loader
Augmentation pipelines in Keras. Down the road, we can use keras-cv for this.
set_seed = 101
rand_flip = layers.RandomFlip("horizontal_and_vertical", seed=set_seed)
rand_rote = layers.RandomRotation(factor=0.01, seed=set_seed)
# more: https://keras.io/api/layers/preprocessing_layers/image_augmentation/
def keras_augment(image, label, centerline):
    # Concatenate image (3 ch) + mask (1 ch) + centerline (1 ch) along the
    # channel axis so the same random transform is applied to all three.
    tensors = tf.concat([image, label, centerline], axis=-1)
    def apply_augment(x):
        # training=True forces the random layers to augment even though
        # we're calling them outside of model.fit().
        x = rand_flip(x, training=True)
        x = rand_rote(x, training=True)
        return x
aug_tensors = apply_augment(tensors)
image, label, centerline = tf.split(aug_tensors, [3, 1, 1], axis=-1)
return image, label, centerline
Load the samples (road image, road mask, road centerline). The pixel values of the road image are normal RGB colors in the 0~255 range. The road mask and road centerline images are also in the 0~255 range (stored with 3 color channels); in read_files below we decode them to a single channel, binarize them to {0, 1}, and normalize the RGB image to [0, 1].
def read_files(image_path, mask=False):
image = tf.io.read_file(image_path)
if mask:
image = tf.io.decode_png(image, channels=1, dtype=tf.uint8)
image = tf.image.resize(
images=image,
size=[img_size, img_size],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
)
image = tf.where(image == 255, 1, 0)
image = tf.cast(image, tf.float32)
else:
image = tf.io.decode_png(image, channels=3, dtype=tf.uint8)
image = tf.image.resize(images=image, size=[img_size, img_size])
image = tf.cast(image, tf.float32)
image = image / 255.
return image
def load_data(image_list, label_list, centerline_list):
image = read_files(image_list)
label = read_files(label_list, mask=True)
center = read_files(centerline_list, mask=True)
return image, label, center
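You can sanity-check the reader eagerly on one sample before wiring the full pipeline. The *_path lists here are hypothetical placeholders for however you gather your file paths:

img, lbl, ctr = load_data(
    train_images_path[0], train_masks_path[0], train_centers_path[0]
)
print(img.shape, img.dtype)              # (img_size, img_size, 3) float32 in [0, 1]
print(lbl.shape, tf.reduce_max(lbl).numpy())  # (img_size, img_size, 1) 1.0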
Notice how we pack the data (in the prepare_dict method below) for a single input and multiple outputs. The same thing can be done for multi-input/multi-output, multi-input/single-output, etc. Again, as mentioned, there are different ways to load such a dataset with the same API, but the overall setup would be the same; I won't go into the alternatives to avoid confusion.
def prepare_dict(image_batch, label_batch, centerline_batch):
    # The dict keys must match the input/output names used in keras.Model
    # ('images', 'mask', 'centerline'), so Keras can route losses/metrics.
    return {'images': image_batch}, {'mask': label_batch, 'centerline': centerline_batch}
def dataloader(image_list, label_list, center_list, split='train'):
dataset = tf.data.Dataset.from_tensor_slices(
(image_list, label_list, center_list)
)
dataset = dataset.shuffle(batch_size * 8) if split == 'train' else dataset
dataset = dataset.repeat() if split == 'train' else dataset
dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(keras_augment) if split == 'train' else dataset
dataset = dataset.map(prepare_dict, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=AUTOTUNE)
return dataset
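With that, building the train and validation pipelines is one call each (again, the path-list names are placeholders for your own lists):

train_ds = dataloader(train_images_path, train_masks_path, train_centers_path, split='train')
valid_ds = dataloader(valid_images_path, valid_masks_path, valid_centers_path, split='valid')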


Compile and Run
Let's compile the model with losses and metrics and fit it. For the losses and metrics, we will use the segmentation_models library (imported as sm) until keras-cv is ready for segmentation tasks. See the loss and metrics parameters below: we pass a loss and a metric function for each output of the model. We could simply pass a single loss/metric method and it would be used for both outputs, but it's nice to know we can pass them in this per-output format.
model.compile(
optimizer=keras.optimizers.Adam(
learning_rate=0.0001
),
loss={
'mask':sm.losses.bce_jaccard_loss,
'centerline': sm.losses.binary_focal_jaccard_loss
},
metrics={
'mask': sm.metrics.iou_score,
'centerline': sm.metrics.f1_score
}
)
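One optional extra: if one task should count more than the other, compile also accepts a loss_weights mapping keyed by the same output names. The weights below are made-up examples, not tuned values:

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss={
        'mask': sm.losses.bce_jaccard_loss,
        'centerline': sm.losses.binary_focal_jaccard_loss
    },
    loss_weights={'mask': 1.0, 'centerline': 2.0},  # hypothetical weighting
    metrics={
        'mask': sm.metrics.iou_score,
        'centerline': sm.metrics.f1_score
    }
)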
history = model.fit(
train_ds,
validation_data=valid_ds,
steps_per_epoch=len(train_images_path) // batch_size,
callbacks=my_callbacks,
epochs=epoch
)
...
...
160/160 [==============================] - 186s
loss: 1.0082 - centerline_loss: 0.7613 - mask_loss: 0.2469 -
centerline_f1-score: 0.4074 - mask_iou_score: 0.8115 -
val_loss: 1.2867 - val_centerline_loss: 0.7986 -
val_mask_loss: 0.4882 - val_centerline_f1-score: 0.3572 -
val_mask_iou_score: 0.6860
160/160 [==============================] - 186s 1s/step -
loss: 0.9827 - centerline_loss: 0.7491 - mask_loss: 0.2336 -
centerline_f1-score: 0.4223 - mask_iou_score: 0.8210 -
val_loss: 1.4251 - val_centerline_loss: 0.8222 -
val_mask_loss: 0.6028 - val_centerline_f1-score: 0.3160 -
val_mask_iou_score: 0.6344
...
...



Full Code and Resources
Here is the full code; it runs on Kaggle (P100, TF 2.11). Here are some resources that might come in handy; most of them are related to segmentation modeling and loss-function selection.