When a custom loss is defined in a Keras model, online sources seem to indicate that the loss should return an array of values (one loss per sample in the batch), something like this:

import tensorflow as tf

def custom_loss_function(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)
    # Mean over the last axis only -> one loss value per sample
    return tf.reduce_mean(squared_difference, axis=-1)

model.compile(optimizer='adam', loss=custom_loss_function)

In the example above, I have no idea when, or even whether, the model reduces this per-sample array to a single batch loss with tf.reduce_sum() or tf.reduce_mean().

In another situation, when we want to implement a custom training loop with a custom loss function, the template to follow according to the Keras documentation is this:

for epoch in range(epochs):
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        with tf.GradientTape() as tape:
            # Forward pass, recorded on the tape
            y_batch_pred = model(x_batch_train, training=True)
            loss_value = custom_loss_function(y_batch_train, y_batch_pred)

        # Backward pass: gradients of the loss w.r.t. the trainable weights
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

So by the book, if I understand correctly, we are supposed to take the mean of the gradients over the batch. Therefore, loss_value above should be a single scalar per batch.

However, the example will work with both of the following variations (a quick shape check follows the list):

  • tf.reduce_mean(squared_difference, axis=-1)  # array of losses, one per sample
  • tf.reduce_mean(squared_difference)           # single mean loss for the batch
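
For concreteness, here is what each variation returns (the toy tensors below are mine, not from the actual model):

import tensorflow as tf

y_true = tf.constant([[0.0, 1.0], [1.0, 1.0]])
y_pred = tf.constant([[0.5, 1.0], [1.0, 0.0]])
squared_difference = tf.square(y_true - y_pred)

print(tf.reduce_mean(squared_difference, axis=-1).shape)  # (2,) -- one loss per sample
print(tf.reduce_mean(squared_difference).shape)           # ()   -- a single scalar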

So why does the first option (the array loss) still work? Is apply_gradients applying small updates sequentially, one per loss value? Is this wrong even though it works?

What is the correct way without a custom loop, and with a custom loop?

Edv Beq

1 Answer

Good question. In my opinion, this is not well documented in the TensorFlow/Keras API. By default, if you pass a non-scalar loss_value to tape.gradient, TensorFlow sums its elements before differentiating (the updates are not applied sequentially, one per sample). In other words, differentiating a per-sample loss vector is equivalent to differentiating the sum of the losses along the batch axis.
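
A minimal sketch illustrating this behaviour (the toy variable and printed values are mine; assuming TF 2.x eager semantics):

import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])

# Non-scalar target: tape.gradient sums the elements before differentiating
with tf.GradientTape() as tape:
    per_sample_loss = tf.square(x)             # shape (3,), one value per "sample"
grad_vector = tape.gradient(per_sample_loss, x)

# Explicitly summing the losses gives the same gradients
with tf.GradientTape() as tape:
    summed_loss = tf.reduce_sum(tf.square(x))  # scalar
grad_sum = tape.gradient(summed_loss, x)

print(grad_vector.numpy())  # [2. 4. 6.]
print(grad_sum.numpy())     # [2. 4. 6.] -- identical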

The losses in the current TensorFlow API also take a reduction argument (for example, tf.losses.MeanSquaredError) that allows specifying how to aggregate the loss along the batch axis.
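
For instance, a short sketch of how the reduction argument changes the aggregation (the example tensors and printed values are mine; assuming TF 2.x, where the options live in tf.keras.losses.Reduction):

import tensorflow as tf

y_true = tf.constant([[0.0, 1.0], [1.0, 1.0]])
y_pred = tf.constant([[0.5, 1.0], [1.0, 0.0]])

# Default: SUM_OVER_BATCH_SIZE -> a single scalar (mean over the batch)
mse_mean = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
print(mse_mean(y_true, y_pred).numpy())  # 0.3125

# NONE -> one loss value per sample, no aggregation across the batch
mse_none = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
print(mse_none(y_true, y_pred).numpy())  # [0.125 0.5]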

rvinas