You can specify the loss in a TensorFlow Keras model in two ways: with add_loss, or via the loss argument of compile. Since the gradient is taken with respect to a single loss in order to do the weight updates, I would imagine that there needs to be one function somehow combining those losses into one. Are they just added together?
For example, let's say I have the following model. The only important lines are self.add_loss(kl_loss) and autoencoder.compile(optimizer=optimizer, loss=r_loss, metrics=[r_loss]).
import tensorflow as tf
from tensorflow.keras import layers, Model

class Autoencoder(Model):
    def __init__(self):
        super(Autoencoder, self).__init__()
        encoder_input = layers.Input(shape=INPUT_SHAPE, name='encoder_input')
        x = encoder_input
        # ...
        x = layers.Flatten()(x)
        mu = layers.Dense(LATENT_DIM, name='mu')(x)
        log_var = layers.Dense(LATENT_DIM, name='log_var')(x)

        def sample(args):
            mu, log_var = args
            epsilon = tf.random.normal(shape=tf.shape(mu), mean=0., stddev=1.)
            return mu + tf.math.exp(log_var / 2) * epsilon

        encoder_output = layers.Lambda(sample, name='encoder_output')([mu, log_var])
        self.encoder = Model(encoder_input, outputs=[encoder_output, mu, log_var])
        self.decoder = tf.keras.Sequential([
            layers.Input(shape=(LATENT_DIM,)),
            # ...
        ])

    def call(self, x):
        encoded, mu, log_var = self.encoder(x)
        kl_loss = tf.math.reduce_mean(
            -0.5 * tf.math.reduce_sum(1 + log_var - tf.math.square(mu) - tf.math.exp(log_var)))
        self.add_loss(kl_loss)
        decoded = self.decoder(encoded)
        return decoded

def train_autoencoder():
    autoencoder = Autoencoder()

    def r_loss(y_true, y_pred):
        return tf.math.reduce_sum(tf.math.square(y_true - y_pred), axis=[1, 2, 3])

    optimizer = tf.keras.optimizers.Adam(1e-4)
    autoencoder.compile(optimizer=optimizer, loss=r_loss, metrics=[r_loss])
When I train my model, I see the following values:
Epoch 00001: saving model to models/autoencoder/cp-autoencoder.ckpt
1272/1272 [==============================] - 249s 191ms/step - batch: 635.5000 - size: 1.0000 - loss: 5300.4540 - r_loss: 2856.8228
Both losses go down together. What exactly is the loss in the above snippet?
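To make the question concrete, here is a minimal sketch I would use to probe the behavior (a toy model, not the autoencoder above, and the constant penalty of 3.0 is just an arbitrary value I picked). It isolates the two loss sources: the compiled loss is forced to zero by using zero weights and zero targets, so if the two are simply summed, the reported loss should equal the add_loss term alone.

```python
import numpy as np
import tensorflow as tf

class Toy(tf.keras.Model):
    """Toy model: zero-initialized linear layer plus a constant add_loss."""
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(
            1, kernel_initializer='zeros', use_bias=False)

    def call(self, x):
        # Constant contribution registered via add_loss.
        self.add_loss(tf.constant(3.0))
        return self.dense(x)

model = Toy()
# lr=0 so the weights stay at zero during the fit call.
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), loss='mse')

x = np.ones((4, 1), dtype='float32')
y = np.zeros((4, 1), dtype='float32')
history = model.fit(x, y, epochs=1, verbose=0)

# Predictions are 0 and targets are 0, so the compiled MSE term is 0;
# the reported loss should then be just the add_loss constant.
print(history.history['loss'][0])
```

If the reported value is 3.0, that would suggest the compiled loss and the add_loss terms are summed into the single scalar used for the gradient step.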