I have a model with multiple outputs. The loss for each output can depend on one of the other outputs, as well as on some masks computed from the data. The overall loss of the model is a weighted sum of these losses.
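Written as a formula, and taking the dependencies from the commented example in `call` in the code further down (the first loss also uses the second output, the other two use one mask each), the quantity to minimise would be the following. Here $y_k$ are the targets, $\hat{y}_k$ the outputs, $m_1, m_2$ the masks, $\ell_k$ the per-output losses and $w_k$ the weights; these symbols are introduced only for illustration:

$$
\mathcal{L} = w_1\,\ell_1(y_1, \hat{y}_1, \hat{y}_2) + w_2\,\ell_2(y_2, \hat{y}_2, m_1) + w_3\,\ell_3(y_3, \hat{y}_3, m_2)
$$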
My model subclasses `tf.keras.Model`, and I am trying to write clean code that I can use with `compile` and `fit`. I would like the weights of the losses to be passed to `compile`.
One way I have found to address the loss-dependency issue (after reading some documentation and this answer) is to feed the mask-type data as inputs to the model and, in the implementation of `call`, to add the loss of each output with `Model.add_loss`. Can someone confirm this? How do I get `y_true` from there?
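A minimal sketch of that approach, assuming TF 2.x: since `call` only sees what is passed as inputs, one way to get `y_true` there is to pass the targets as extra inputs next to the image and the masks. The class name, the small `Conv2D` standing in for ResNet101, the mask shapes and the squared-error loss terms are all hypothetical placeholders, not the real losses:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class TargetsAsInputsModel(tf.keras.Model):
    """Hypothetical model: y_true tensors are extra inputs so call() can see them."""

    def __init__(self, n=2, **kwargs):
        super().__init__(**kwargs)
        # Tiny stand-in for the ResNet101 feature extractor (assumption, keeps it fast).
        self.backbone = layers.Conv2D(8, kernel_size=(1, 1), activation='relu')
        self.out1 = layers.Conv2D(n, kernel_size=(1, 1), activation='sigmoid')
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1))
        self.out3 = layers.Conv2D(2 * n, kernel_size=(1, 1))

    def call(self, inputs):
        # Targets travel alongside the image and the masks.
        img, mask1, mask2, y1, y2, y3 = inputs
        x = self.backbone(img)
        out1, out2, out3 = self.out1(x), self.out2(x), self.out3(x)
        # Placeholder loss terms; each may use another output or a mask.
        self.add_loss(tf.reduce_mean(tf.square(y1 - out1 * out2)))
        self.add_loss(tf.reduce_mean(mask1 * tf.square(y2 - out2)))
        self.add_loss(tf.reduce_mean(mask2 * tf.square(y3 - out3)))
        return out1, out2, out3


n, batch = 2, 8
rng = np.random.default_rng(0)
img = rng.random((batch, 4, 4, 3), dtype=np.float32)
mask1 = np.ones((batch, 4, 4, n), np.float32)
mask2 = np.ones((batch, 4, 4, 2 * n), np.float32)
y1 = rng.random((batch, 4, 4, n), dtype=np.float32)
y2 = rng.random((batch, 4, 4, n), dtype=np.float32)
y3 = rng.random((batch, 4, 4, 2 * n), dtype=np.float32)

model = TargetsAsInputsModel(n=n)
# No `loss=` here: fit() minimises the sum of the add_loss terms.
model.compile(optimizer='adam')
history = model.fit((img, mask1, mask2, y1, y2, y3), epochs=1, batch_size=4, verbose=0)
```

One drawback of this design: because the targets are model inputs, something has to be fed for them at inference time as well (or a separate inference path is needed).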
If this is a good solution, how do I specify in `compile` that the overall model loss is a weighted sum of those losses, and how can I access them?
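As far as I can tell, `loss_weights` in `compile` is only applied to the losses passed through `loss=`; terms added with `add_loss` are summed unweighted on top. A sketch of one workaround, assuming TF 2.x: override `compile` so it accepts the weights and multiply each term by its weight before calling `add_loss`. The `dep_loss_weights` argument, the class name and the loss terms are hypothetical, not Keras API:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class WeightedAddLossModel(tf.keras.Model):
    """Hypothetical two-output model whose add_loss terms are weighted."""

    def __init__(self, n=2, **kwargs):
        super().__init__(**kwargs)
        self.out1 = layers.Conv2D(n, kernel_size=(1, 1), activation='sigmoid')
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1))
        self.dep_loss_weights = (1.0, 1.0)

    def compile(self, dep_loss_weights=(1.0, 1.0), **kwargs):
        # `dep_loss_weights` is a hypothetical extra argument; the rest goes to Keras.
        super().compile(**kwargs)
        self.dep_loss_weights = tuple(float(w) for w in dep_loss_weights)

    def call(self, inputs):
        img, y1, y2 = inputs
        out1, out2 = self.out1(img), self.out2(img)
        w1, w2 = self.dep_loss_weights
        # Weighted placeholder losses; out1's loss also depends on out2.
        self.add_loss(w1 * tf.reduce_mean(tf.square(y1 - out1 * out2)))
        self.add_loss(w2 * tf.reduce_mean(tf.square(y2 - out2)))
        return out1, out2


rng = np.random.default_rng(0)
img = rng.random((8, 4, 4, 3), dtype=np.float32)
y1 = rng.random((8, 4, 4, 2), dtype=np.float32)
y2 = rng.random((8, 4, 4, 2), dtype=np.float32)

model = WeightedAddLossModel()
model.compile(optimizer='adam', dep_loss_weights=(1.0, 0.5))
history = model.fit((img, y1, y2), epochs=1, batch_size=4, verbose=0)
```

The weights are plain Python floats read when the training function is traced, so changing them means calling `compile` again.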
Also, would it be better to call `add_loss` on each layer in the implementation of the model's `call`? Same question: how do I access those losses in `compile`?
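For the layer-level variant, a minimal sketch (hypothetical names, TF 2.x assumed): a pass-through layer calls `add_loss` and returns its prediction unchanged. Terms registered by sub-layers are collected in `model.losses` after a call, but `compile` has no handle on them individually, so here the weight travels with the layer's constructor rather than with `compile`:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class DependentLoss(layers.Layer):
    """Hypothetical pass-through layer that registers one weighted loss term."""

    def __init__(self, loss_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.loss_weight = loss_weight

    def call(self, inputs):
        y_true, y_pred, extra = inputs
        # Placeholder loss: squared error modulated by `extra` (e.g. a mask).
        self.add_loss(self.loss_weight * tf.reduce_mean(extra * tf.square(y_true - y_pred)))
        return y_pred  # the prediction itself is not modified


class LayerLossModel(tf.keras.Model):
    def __init__(self, n=2, **kwargs):
        super().__init__(**kwargs)
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1))
        self.loss2 = DependentLoss(loss_weight=0.5, name='loss2')

    def call(self, inputs):
        img, mask1, y2 = inputs
        out2 = self.out2(img)
        return self.loss2((y2, out2, mask1))


model = LayerLossModel()
x = (np.ones((2, 4, 4, 3), np.float32),
     np.ones((2, 4, 4, 2), np.float32),
     np.zeros((2, 4, 4, 2), np.float32))
_ = model(x)
print(len(model.losses))  # 1: the term registered by the sub-layer
```

The targets still have to reach the layer as inputs, so this only moves the `add_loss` call; it does not by itself answer how to get `y_true`.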
If this is not a good solution, what is a good one?
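One alternative, sketched under the assumption of TF 2.x and that overriding `train_step` is acceptable: let `fit` deliver the targets as usual and build the weighted sum in `train_step`. In this sketch the masks are packed into `y` next to the targets, since only the losses need them, so `call` and `predict` take the image alone. The class name, the `out_loss_weights` argument of the overridden `compile` and the loss terms are hypothetical:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class TrainStepModel(tf.keras.Model):
    """Hypothetical model: y_true arrives via fit(x, y); train_step builds the loss."""

    def __init__(self, n=2, **kwargs):
        super().__init__(**kwargs)
        self.out1 = layers.Conv2D(n, kernel_size=(1, 1), activation='sigmoid')
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1))
        self.out3 = layers.Conv2D(2 * n, kernel_size=(1, 1))
        self.out_loss_weights = (1.0, 1.0, 1.0)

    def compile(self, out_loss_weights=(1.0, 1.0, 1.0), **kwargs):
        # Hypothetical extra argument so the weights are still given at compile time.
        super().compile(**kwargs)
        self.out_loss_weights = tuple(float(w) for w in out_loss_weights)

    def call(self, img):
        # Prediction needs only the image.
        return self.out1(img), self.out2(img), self.out3(img)

    def weighted_loss(self, y, y_pred):
        y1, y2, y3, mask1, mask2 = y  # masks travel with the targets
        out1, out2, out3 = y_pred
        w1, w2, w3 = self.out_loss_weights
        # Placeholder terms with the dependencies from the question.
        l1 = tf.reduce_mean(tf.square(y1 - out1 * out2))
        l2 = tf.reduce_mean(mask1 * tf.square(y2 - out2))
        l3 = tf.reduce_mean(mask2 * tf.square(y3 - out3))
        return w1 * l1 + w2 * l2 + w3 * l3

    def train_step(self, data):
        img, y = data  # assumes fit() receives no sample weights
        with tf.GradientTape() as tape:
            y_pred = self(img, training=True)
            loss = self.weighted_loss(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {'loss': loss}


n, batch = 2, 8
rng = np.random.default_rng(0)
img = rng.random((batch, 4, 4, 3), dtype=np.float32)
y1 = rng.random((batch, 4, 4, n), dtype=np.float32)
y2 = rng.random((batch, 4, 4, n), dtype=np.float32)
y3 = rng.random((batch, 4, 4, 2 * n), dtype=np.float32)
mask1 = np.ones((batch, 4, 4, n), np.float32)
mask2 = np.ones((batch, 4, 4, 2 * n), np.float32)

model = TrainStepModel(n=n)
model.compile(optimizer='adam', out_loss_weights=(1.0, 0.5, 0.2))
history = model.fit(img, (y1, y2, y3, mask1, mask2), epochs=1, batch_size=4, verbose=0)
preds = model.predict(img, verbose=0)
```

Keeping targets and masks out of `call` means inference needs only the image; the cost is that `evaluate` would need a matching `test_step` override to report the same loss.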
import tensorflow as tf
from tensorflow.keras import layers

class MyModel(tf.keras.Model):
    def __init__(self, base_trainable=False, feature_extractor=None, n=4, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        if feature_extractor is not None:
            self.feature_extractor = feature_extractor
        else:
            # ResNet101 has no `trainable` argument; set the attribute after creating it
            self.feature_extractor = tf.keras.applications.ResNet101(include_top=False,
                                                                     weights='imagenet')
            self.feature_extractor.trainable = base_trainable
        self.out1 = layers.Conv2D(n, kernel_size=(1, 1), activation='sigmoid', name='out1')
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1), name='out2')
        self.out3 = layers.Conv2D(2 * n, kernel_size=(1, 1), name='out3')

    def call(self, inputs):
        img, mask1, mask2 = inputs
        x = self.feature_extractor(img)
        out1 = self.out1(x)
        out2 = self.out2(x)
        out3 = self.out3(x)
        # compute losses for each output? (but how do I access each y_true?...)
        # ex:
        #
        # self.add_loss(my_loss_for_out1(y1_true??,
        #                                out1,
        #                                out2))
        # self.add_loss(my_loss_for_out2(y2_true??,
        #                                out2,
        #                                mask1))
        # self.add_loss(my_loss_for_out3(y3_true??,
        #                                out3,
        #                                mask2))
        return out1, out2, out3
model = MyModel()
model.compile(loss=???,
              loss_weights=???)
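For comparison, if each loss depended only on its own `y_true`/`y_pred` pair, the `???` placeholders could be filled with the standard per-output `compile` arguments: one loss and one weight per returned output, in order. A sketch with a tiny hypothetical stand-in model, assuming TF 2.x; the particular losses and weights are arbitrary, and this form cannot express the cross-output and mask dependencies:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class IndependentOutputs(tf.keras.Model):
    """Hypothetical stand-in with the same three output heads."""

    def __init__(self, n=2, **kwargs):
        super().__init__(**kwargs)
        self.out1 = layers.Conv2D(n, kernel_size=(1, 1), activation='sigmoid')
        self.out2 = layers.Conv2D(n, kernel_size=(1, 1))
        self.out3 = layers.Conv2D(2 * n, kernel_size=(1, 1))

    def call(self, img):
        # A list, so outputs, losses and targets share one structure.
        return [self.out1(img), self.out2(img), self.out3(img)]


model = IndependentOutputs()
model.compile(
    optimizer='adam',
    # One loss per output, in the order call() returns them.
    loss=[tf.keras.losses.BinaryCrossentropy(),
          tf.keras.losses.MeanSquaredError(),
          tf.keras.losses.MeanSquaredError()],
    # Total loss = 1.0 * loss1 + 0.5 * loss2 + 0.2 * loss3
    loss_weights=[1.0, 0.5, 0.2],
)

rng = np.random.default_rng(0)
img = rng.random((8, 4, 4, 3), dtype=np.float32)
y1 = rng.random((8, 4, 4, 2), dtype=np.float32)
y2 = rng.random((8, 4, 4, 2), dtype=np.float32)
y3 = rng.random((8, 4, 4, 4), dtype=np.float32)
history = model.fit(img, [y1, y2, y3], epochs=1, batch_size=4, verbose=0)
```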
Thank you