I am trying to optimize a model with the following two loss functions:

```python
import tensorflow.keras.losses as kls

def loss_1(pred, weights, logits):
    # weighted sparse categorical cross-entropy, with per-sample weights
    weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
    policy_loss = weighted_sparse_ce(pred, logits, sample_weight=weights)
    return policy_loss
```
and

```python
def loss_2(y_pred, y):
    return kls.mean_squared_error(y_pred, y)
```
However, because TensorFlow 2 expects a loss function to be of the form

```python
def fn(y_true, y_pred):
    ...
```
I am using a workaround for `loss_1`: I pack `pred` and `weights` into a single tensor before passing it to `loss_1` in the call to `model.fit`, and then unpack them inside `loss_1`. This is inelegant and nasty because `pred` and `weights` are of different data types, so it requires an additional cast, pack, unpack, and un-cast every time I call `model.fit`.
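To make the workaround concrete, it looks roughly like the sketch below (simplified; the helper name `pack_target` and the cast to `float32` are illustrative, not my exact code):

```python
import tensorflow as tf
import tensorflow.keras.losses as kls

def pack_target(pred, weights):
    # pred holds integer class indices and weights holds floats, so pred has
    # to be cast before the two can be stacked into a single float tensor
    return tf.stack([tf.cast(pred, tf.float32), weights], axis=-1)

def loss_1(y_true_packed, logits):
    # undo the packing: recover pred and weights from the combined tensor
    pred = tf.cast(y_true_packed[..., 0], tf.int32)
    weights = y_true_packed[..., 1]
    weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
    return weighted_sparse_ce(pred, logits, sample_weight=weights)

# in fit, the packed tensor is passed as the target for this output, e.g.
# model.fit(x, pack_target(actions, advantages), ...)
```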
Furthermore, I am aware of the `sample_weight` argument to `fit`, which is something like a solution to this question. It might be a workable solution were it not for the fact that I am using two loss functions and I only want the `sample_weight` applied to one of them. Also, even if this were a solution, it would not generalize to other types of custom loss functions.
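What I mean is sketched below (a toy two-headed model; the output names `policy` and `value`, the layer sizes, and the random data are all just for illustration):

```python
import numpy as np
import tensorflow as tf
import tensorflow.keras.losses as kls

# toy two-output model: a policy head producing logits and a value head
inp = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(16, activation='relu')(inp)
policy = tf.keras.layers.Dense(3, name='policy')(hidden)  # logits
value = tf.keras.layers.Dense(1, name='value')(hidden)
model = tf.keras.Model(inp, [policy, value])

model.compile(
    optimizer='adam',
    loss={'policy': kls.SparseCategoricalCrossentropy(from_logits=True),
          'value': kls.MeanSquaredError()},
)

x = np.random.randn(32, 4).astype('float32')
actions = np.random.randint(0, 3, size=(32,))
returns = np.random.randn(32, 1).astype('float32')
advantages = np.random.rand(32).astype('float32')

# As far as I can tell, a single sample_weight array gets applied to the
# losses of both outputs, whereas I only want it to weight the policy loss.
model.fit(x, {'policy': actions, 'value': returns},
          sample_weight=advantages, epochs=1, verbose=0)
```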
All that being said, my question, stated concisely, is:
What is the best way to create a loss function with an arbitrary number of arguments in TensorFlow 2?
Another thing I have tried is passing a `tf.tuple`, but that also seems to violate TensorFlow's expectations for a loss function's input.