I have a VAE model that I've split into encoder and decoder parts, and I've implemented a custom loss. A simplified example is below:
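from keras import backend as K
from keras.layers import Input
from keras.models import Model

inputs = Input(shape=(self.image_height, self.image_width, self.image_channel))
encoded = build_encoder(inputs)  # encoder half of the VAE
decoded = build_decoder(encoded)  # decoder half, fed by the encoder output
model = Model(inputs, decoded)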
The loss (also simplified) is:
loss = K.mean(decoded[0] + decoded[1] + encoded[0]**2)
model.add_loss(loss)
model.compile(optimizer=self.optimizer)
My main problem is that I want to use Keras' ModelCheckpoint callback, which I believe requires me to define custom metrics. However, everything I have found online is similar to https://keras.io/metrics/#custom_metrics, where a custom metric only takes y_true and y_pred and computes its value from those two tensors. How would I implement this for my model, where the loss is calculated from several tensors (outputs of both the encoder and the decoder), not only from the final output "decoded"?
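One possible direction, sketched under assumptions: a loss registered with add_loss() is still reported by Keras as "loss" (and as "val_loss" when validation data is supplied), so ModelCheckpoint can monitor "val_loss" directly without a y_true/y_pred metric. The sketch below assumes TensorFlow 2.x with tf.keras; build_encoder, build_decoder, VAELossLayer, the 8x8x1 input shape and the random data are hypothetical stand-ins, not the question's real code. It also swaps the model.add_loss(...) call on symbolic tensors for a small layer that calls self.add_loss() inside call(), because, as far as I know, Keras 3 dropped symbolic add_loss() but still supports it inside a layer's call().

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


def build_encoder(x):
    # Hypothetical encoder: flatten the image and project to a 2-D latent mean.
    h = keras.layers.Flatten()(x)
    return keras.layers.Dense(2, name="z_mean")(h)


def build_decoder(z):
    # Hypothetical decoder: map the latent code back to an 8x8x1 image.
    h = keras.layers.Dense(8 * 8 * 1, activation="sigmoid")(z)
    return keras.layers.Reshape((8, 8, 1))(h)


class VAELossLayer(keras.layers.Layer):
    """Hypothetical layer: registers a loss built from encoder AND decoder
    tensors, then passes the decoded image through unchanged."""

    def call(self, inputs):
        x, z_mean, decoded = inputs
        reconstruction = tf.reduce_mean(tf.square(x - decoded))
        latent_penalty = tf.reduce_mean(tf.square(z_mean))
        self.add_loss(reconstruction + latent_penalty)
        return decoded


inputs = keras.Input(shape=(8, 8, 1))
z_mean = build_encoder(inputs)
decoded = build_decoder(z_mean)
outputs = VAELossLayer()([inputs, z_mean, decoded])
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam")  # no `loss=`: training minimises the add_loss term

# The add_loss total is logged as "loss"/"val_loss", so the checkpoint
# monitors it directly; no y_true/y_pred metric is defined here.
ckpt_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
checkpoint = keras.callbacks.ModelCheckpoint(
    ckpt_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
)

x_train = np.random.rand(64, 8, 8, 1).astype("float32")
history = model.fit(
    x_train,
    epochs=3,
    batch_size=16,
    validation_split=0.25,
    callbacks=[checkpoint],
    verbose=0,
)
print(sorted(history.history))
```

Building the loss inside a layer keeps the encoder and decoder tensors in scope at the point where the loss is computed, which the y_true/y_pred metric signature does not provide.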