I am training a deep network implementing the `tf.keras.Model` API, using `tf.keras.Model.fit()` (TensorFlow 2.0), to segment images. The input data comes from a dataset using `tf.data.TFRecordDataset`. Thus, the `fit()` function takes care of loading the next batch, computing the predictions for the current batch, and computing the corresponding loss. The training appears to run well, judging by the outputs in the terminal. Now, what I would like to do is to display the current training batch, including the corresponding network predictions, on TensorBoard, e.g. every 100 training iterations or so.
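For context, my input pipeline looks roughly like the following sketch (the feature names `'image'` and `'mask'`, the raw-bytes decoding, and the file name are hypothetical placeholders, not my actual code):

```python
import tensorflow as tf

# Hypothetical record layout: raw 'image' and 'mask' bytes per example.
FEATURE_SPEC = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
}

def parse_record(serialized):
    """Decode one serialized tf.train.Example into (image, mask) tensors."""
    example = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_raw(example['image'], tf.uint8)
    mask = tf.io.decode_raw(example['mask'], tf.uint8)
    return tf.cast(image, tf.float32) / 255.0, mask

# train_dataset = tf.data.TFRecordDataset('train.tfrecords').map(parse_record).batch(8)
```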
TensorFlow's main way to modify/interact with the training iterations implemented within `fit()` is through `tf.keras.callbacks.Callback`; for instance, there is a dedicated `tf.keras.callbacks.TensorBoard` callback to include TensorBoard functionality in the training. This makes it very easy to plot error metrics during training. However, as far as I can see, this callback is of no use for displaying the current batch/predictions, as neither the actual input images nor the predictions are ever passed to the callbacks by `fit()`; all that is ever passed is the current batch index and the current error metrics (right?).
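To illustrate, a minimal custom callback (a sketch, not my actual code) only ever sees the batch index and a `logs` dict of metric values:

```python
import tensorflow as tf

class InspectLogs(tf.keras.callbacks.Callback):
    """Minimal sketch: fit() hands a callback only the batch index and a
    dict of metric values - not the input tensors or the predictions."""

    def on_train_batch_end(self, batch, logs=None):
        if batch % 100 == 0:
            # logs contains metric values such as 'loss'; no images, no predictions
            print(batch, sorted((logs or {}).keys()))
```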
Now, a kind of brute-force solution could be the following: I could create two copies of the input training dataset, let's call them `training_dset` and `training_dset_copy`, and pass `training_dset` to the `fit()` function. Then, I could pass a custom callback to `fit()` which would, at every iteration, `take()` a batch from `training_dset_copy`, call `model.predict()` on that batch, and then plot the images to TensorBoard. Since `training_dset_copy` is an exact copy of `training_dset`, the `take()` operation would access the same batch as is actually used in the training. However, having two copies of the same dataset, only to plot some batches in TensorBoard, seems to me a rather wasteful solution, and I am convinced that there must be a more elegant way to achieve this.
There are somewhat related posts on StackOverflow, such as TensorFlow 2.0 Keras: How to write image summaries for TensorBoard; however, that post does not show how to also display the predictions of the current batch in TensorBoard. Furthermore, there is https://www.tensorflow.org/tensorboard/r2/image_summaries, a tutorial on showing images in TensorBoard. However, that tutorial assumes that I already have the input image batch available as e.g. a NumPy array. My question, in the end, boils down to accessing the current batch used during `fit()` in the first place, so that I can then show it on TensorBoard.
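For completeness, the tutorial side is indeed straightforward once a batch is available as an array; something like this sketch (the log directory is a placeholder):

```python
import numpy as np
import tensorflow as tf

# Logging a batch is the easy part once you HAVE it as an array;
# getting hold of the batch inside fit() is the actual problem.
writer = tf.summary.create_file_writer('/tmp/demo_logs')  # placeholder path
batch = np.random.rand(4, 28, 28, 1).astype('float32')    # dummy image batch
with writer.as_default():
    tf.summary.image('training_batch', batch, step=0, max_outputs=4)
```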
For simplicity, and also for consistency with the previously referenced StackOverflow post, I want to reuse their MNIST example:
```python
import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

def scale(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

def augment(image, label):
    return image, label  # do nothing atm

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.map(scale).map(augment).batch(32)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(dataset, epochs=5, callbacks=[tf.keras.callbacks.TensorBoard(log_dir='D:\\tmp\\test')])
```
Can anyone show me a good way to make the above code plot the current training batch, including the corresponding predictions, on TensorBoard?