
I am training a deep network implementing the tf.keras.Model API, using tf.keras.Model.fit() (Tensorflow 2.0), to segment images. The input data comes from a dataset built on tf.data.TFRecordDataset. Thus, the fit() function takes care of loading the next batch, computing the predictions for the current batch, and computing the corresponding loss. The training appears to run well, judging by the output in the terminal. Now, what I would like to do is display the current training batch, including the corresponding network predictions, on Tensorboard, e.g. every 100 training iterations or so.

Tensorflow's main way to modify/interact with the training iterations implemented within fit() is through tf.keras.callbacks.Callback; for instance, there is a dedicated tf.keras.callbacks.TensorBoard callback to include Tensorboard functionality in the training. This makes it very easy to plot error metrics during training. However, as far as I can see, this callback is completely useless for displaying the current batch/predictions, as neither the actual input images nor the predictions are ever passed to the callbacks by fit(); all that is ever passed is the current batch index, as well as the current error metrics (right?).
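To make concrete what a callback actually receives per batch, here is a minimal sketch (the class name InspectBatchCallback is mine, chosen for illustration): on_train_batch_end gets only an integer batch index and a logs dict of metric values, not the input data or predictions.

```python
import tensorflow as tf

class InspectBatchCallback(tf.keras.callbacks.Callback):
    """Sketch: records what fit() passes to a callback after each batch --
    only the batch index and the current metric values, never the data."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def on_train_batch_end(self, batch, logs=None):
        # 'batch' is just an integer index; 'logs' is a dict like
        # {'loss': ..., 'accuracy': ...}. The images are NOT available here.
        self.seen.append((batch, dict(logs or {})))
```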

Now, a kind of brute-force solution could be as follows: I could create 2 copies of the input training dataset; let's denote those training_dset and training_dset_copy, and pass training_dset to the fit() function. Then, I could pass a TensorBoard callback to fit() which would, at every iteration, take() a batch from training_dset_copy, call model.predict() on that batch, and then plot the images to tensorboard. Since training_dset_copy is an exact copy of training_dset, the take() operation would access the same batch as actually used in the training. However, maintaining 2 copies of the same dataset, only to plot some batches in tensorboard, seems to me a rather wasteful solution, and I am convinced that there must be a more elegant way to achieve this.
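For reference, the brute-force idea could be sketched roughly as follows. The make_dataset helper and the BatchImageCallback class are hypothetical names of mine; the sketch assumes the pipeline is deterministic (no seedless shuffle), so that building it twice yields identical batches in identical order.

```python
import tensorflow as tf

def make_dataset():
    # Hypothetical stand-in for the real TFRecordDataset pipeline; the point
    # is only that constructing the same deterministic pipeline twice yields
    # identical batches in identical order.
    x = tf.reshape(tf.range(64 * 28 * 28, dtype=tf.float32), (64, 28, 28, 1))
    return tf.data.Dataset.from_tensor_slices(x / 255.0).batch(32)

training_dset = make_dataset()       # passed to fit()
training_dset_copy = make_dataset()  # iterated in lockstep by the callback

class BatchImageCallback(tf.keras.callbacks.Callback):
    """Brute-force sketch: pull the matching batch from a second copy of the
    dataset, run predict() on it, and log it as an image summary."""

    def __init__(self, dataset_copy, log_dir, every_n=100):
        super().__init__()
        self.dataset_copy = dataset_copy
        self.writer = tf.summary.create_file_writer(log_dir)
        self.every_n = every_n
        self.step = 0

    def on_epoch_begin(self, epoch, logs=None):
        # Restart the copy each epoch to stay in lockstep with fit().
        self.iterator = iter(self.dataset_copy)

    def on_train_batch_end(self, batch, logs=None):
        # Consume one batch per training step, even when not logging,
        # so the copy stays synchronized with the batch fit() just used.
        images = next(self.iterator)
        if self.step % self.every_n == 0:
            preds = self.model.predict(images, verbose=0)
            with self.writer.as_default():
                tf.summary.image("training_batch", images,
                                 step=self.step, max_outputs=4)
                # For a segmentation model the predictions are images too,
                # so they could be logged the same way, e.g.:
                # tf.summary.image("predictions", preds, step=self.step)
        self.step += 1
```

Usage would be something like model.fit(training_dset, callbacks=[BatchImageCallback(training_dset_copy, log_dir)]) — which illustrates exactly the wasteful duplication described above.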

There are somewhat related posts on StackOverflow, such as TensorFlow 2.0 Keras: How to write image summaries for TensorBoard; however, that post does not reveal how to also show the predictions of the current batch in tensorboard. Furthermore, there is https://www.tensorflow.org/tensorboard/r2/image_summaries, a tutorial on showing images on tensorboard. However, that tutorial assumes that I already have the input image batch as, e.g., a numpy array. My question thus boils down to accessing the current batch used during fit() in the first place, so that I can then show it on tensorboard.
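The second half of the problem — turning a batch plus predictions into a single summary image — follows the plot-to-image trick from that tutorial. A sketch, assuming the batch and predictions are already available as numpy arrays (the function name batch_to_image_summary is mine):

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import tensorflow as tf

def batch_to_image_summary(images, predictions, writer, step, n=4):
    """Render the first n images with their predicted labels into one
    matplotlib figure and log it as a single TensorBoard image summary."""
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2))
    for i, ax in enumerate(axes):
        ax.imshow(images[i].squeeze(), cmap="gray")
        ax.set_title("pred: %d" % predictions[i].argmax())
        ax.axis("off")
    # Serialize the figure to PNG and decode it back into a tensor,
    # as in the official image-summaries tutorial.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    png = tf.image.decode_png(buf.getvalue(), channels=4)
    with writer.as_default():
        tf.summary.image("batch_with_predictions", png[tf.newaxis], step=step)
```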

For simplicity, and also for consistency with the previously referenced StackOverflow post, I want to reuse their MNIST example:

import tensorflow as tf
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

def scale(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

def augment(image, label):
    return image, label  # do nothing atm

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.map(scale).map(augment).batch(32)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(dataset, epochs=5, callbacks=[tf.keras.callbacks.TensorBoard(log_dir='D:\\tmp\\test')])

Can anyone show me a good way how to make the above code plot the current training batch, incl. corresponding predictions, on tensorboard?

  • "right?": right. – bers Jan 09 '20 at 00:59
  • If you don't rely on eager mode, this has been proposed for TF1 ( https://stackoverflow.com/questions/47079111/) and works in TF2 if you disable eager mode (https://stackoverflow.com/questions/58229537). This allows you to access the current batch. You need to combine it with the code to write an image summary to TB, which is more difficult than it looks because many code examples rely on eager mode. But it's a start. – bers Jan 09 '20 at 01:01
