I'm having trouble finding the most elegant way to write summaries when using tf.train.MonitoredTrainingSession(). I'm training a model using an Iterator over a TFRecordDataset, and everything works fine. My code looked like this:
```python
import tensorflow as tf  # TF 1.x API

# m is defined, takes an iterator
with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
    while not sess.should_stop():
        sess.run(m.train_op)
```
Then I added a second Iterator that iterates over a validation set. I don't want to train the model on it; I only want to compute the accuracy and the loss on the validation data. Here is what I have so far:
```python
import tensorflow as tf  # TF 1.x API

it_train_handle = training_dataset.make_one_shot_iterator().string_handle()
it_valid_handle = validation_dataset.make_one_shot_iterator().string_handle()

it_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    it_handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# m (the model) is defined here and takes next_element as its input

with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
    # it_train_handle and it_valid_handle are already string-handle tensors,
    # so run them directly (a Tensor has no string_handle() method).
    training_handle = sess.run(it_train_handle)
    validation_handle = sess.run(it_valid_handle)
    while not sess.should_stop():
        sess.run(m.train_op, feed_dict={it_handle: training_handle})
        sess.run([m.accuracy, m.loss], feed_dict={it_handle: validation_handle})
```
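For anyone unfamiliar with the handle pattern, here is a plain-Python sketch of what `from_string_handle` does conceptually: a single `get_next()` whose source is chosen by the handle passed at call time, while each underlying iterator keeps its own position. `FeedableIteratorSketch` and `run_alternating` are hypothetical names for illustration, not TensorFlow APIs, and `StopIteration` stands in for `tf.errors.OutOfRangeError`.

```python
from typing import Any, Dict, Iterable, Iterator, List, Tuple


class FeedableIteratorSketch:
    """Hypothetical plain-Python stand-in (not a TensorFlow API) for the
    pattern built with tf.data.Iterator.from_string_handle."""

    def __init__(self) -> None:
        self._sources: Dict[str, Iterator[Any]] = {}

    def register(self, name: str, data: Iterable[Any]) -> str:
        # Roughly like sess.run(<iterator>.string_handle()): returns an
        # opaque handle that identifies one underlying iterator.
        handle = "handle-" + name
        self._sources[handle] = iter(data)
        return handle

    def get_next(self, handle: str) -> Any:
        # Roughly like running next_element with feed_dict={it_handle: handle}.
        # Raises StopIteration once that particular source is exhausted.
        return next(self._sources[handle])


def run_alternating(train: Iterable[Any], valid: Iterable[Any]) -> List[Tuple[str, Any]]:
    """One training element, then one validation element, until either runs out."""
    it = FeedableIteratorSketch()
    train_h = it.register("train", train)
    valid_h = it.register("valid", valid)
    log: List[Tuple[str, Any]] = []
    try:
        while True:
            log.append(("train", it.get_next(train_h)))
            log.append(("valid", it.get_next(valid_h)))
    except StopIteration:
        pass  # whichever source is exhausted first ends the loop
    return log


if __name__ == "__main__":
    print(run_alternating(range(5), ["a", "b"]))
```

With 5 training and 2 validation elements, the loop stops during the third step because the validation source runs out first. As far as I understand TF 1.x, the same can happen above: the one-shot validation iterator raises OutOfRangeError once the validation set is used up, and MonitoredTrainingSession treats that as the end of the `with` block, so training may stop early.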
Now, I could probably define two FileWriters that write to two different log directories, like in this question, but I am fairly sure there is a better way. After all, MonitoredTrainingSession already has a default summary writer that I never defined; isn't the whole point of using it that it handles this kind of thing automatically?
I know that tf.train.SummarySaverHook exists, and I guess it is part of the solution, but how would I tell the session which summary saver to use?
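To make the goal concrete, here is a plain-Python sketch of the behaviour I am after, modelled loosely on the idea of a hook's `after_run` callback: every N training steps, evaluate on the validation data and record the results under a separate "validation" run next to the "train" run, the way a second writer with its own log directory would. `PeriodicEvalHookSketch`, `every_n_steps` and `evaluate` are hypothetical names; this is not how SummarySaverHook is actually implemented, and the real `SessionRunHook.after_run` has a different signature.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

LogEntry = Tuple[int, str, float]  # (step, tag, value)


class PeriodicEvalHookSketch:
    """Hypothetical plain-Python analogue of a session hook (not a TensorFlow
    API): logs training metrics every step and validation metrics every
    `every_n_steps` steps, each under its own run name."""

    def __init__(self, every_n_steps: int,
                 evaluate: Callable[[], Dict[str, float]]) -> None:
        self.every_n_steps = every_n_steps
        self.evaluate = evaluate  # e.g. something that runs m.loss on validation data
        self.logs: Dict[str, List[LogEntry]] = defaultdict(list)

    def after_run(self, step: int, train_metrics: Dict[str, float]) -> None:
        # Training metrics go to the "train" run on every call.
        for tag, value in train_metrics.items():
            self.logs["train"].append((step, tag, value))
        # Validation is evaluated only periodically and goes to a separate run.
        if step % self.every_n_steps == 0:
            for tag, value in self.evaluate().items():
                self.logs["validation"].append((step, tag, value))


if __name__ == "__main__":
    hook = PeriodicEvalHookSketch(every_n_steps=3, evaluate=lambda: {"loss": 0.5})
    for step in range(1, 7):
        hook.after_run(step, {"loss": 1.0 / step})
    print(hook.logs["validation"])
```

Logging the same tag (`loss`) under two runs is what lets TensorBoard draw the training and validation curves on one chart, since it treats each log subdirectory as a separate run.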
Any help would be greatly appreciated. Thanks in advance.