
I'm having trouble finding the most elegant way of writing summaries using tf.train.MonitoredTrainingSession(). I'm training a model using an Iterator over a TFRecordDataset, and everything works fine. My code looked like this:

# m is defined, takes an iterator
with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:
    while not sess.should_stop():
        sess.run(m.train_op)

Then I added a second Iterator that iterates over a validation set. I don't want to train the model on that; I just want to retrieve the accuracy and the loss on the validation data. Here is what I have so far:

it_train_handle = training_dataset.make_one_shot_iterator().string_handle()
it_valid_handle = validation_dataset.make_one_shot_iterator().string_handle()

it_handle = tf.placeholder(tf.string, shape=[])

iterator = tf.data.Iterator.from_string_handle(it_handle,
    training_dataset.output_types, training_dataset.output_shapes)

next_element = iterator.get_next()

# m (the model) is defined here; it takes next_element as its input

with tf.train.MonitoredTrainingSession(checkpoint_dir=path) as sess:

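    # it_train_handle and it_valid_handle are already string-handle tensors
    # (created before the session, which finalizes the graph), so run them directly
    training_handle = sess.run(it_train_handle)
    validation_handle = sess.run(it_valid_handle)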

    while not sess.should_stop():
        sess.run(m.train_op, feed_dict={it_handle: training_handle})
        sess.run([m.accuracy, m.loss], feed_dict={it_handle: validation_handle})

Now, I could probably define two FileWriters and have them write to two different log directories, like in this question, but I am pretty sure there is a better way of doing it (after all, MonitoredTrainingSession has a default file writer that I never defined; the whole point of using it is that it does that kind of stuff automatically, right?)
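A minimal sketch of the two-writer idea combined with the session's own summary saving, assuming the `tf.compat.v1` graph-mode API (TF 1.14+ or 2.x). The built-in `SummarySaverHook` keeps writing the training summaries to `checkpoint_dir`, and one extra `FileWriter` in a `validation` subdirectory receives the validation loss. Fetching the handles and running validation both go through `MonitoredSession.run_step_fn`, whose `step_context.session` is the raw session, so no hooks fire on those calls; otherwise the built-in hook could evaluate the summary op with the validation feed, or with no handle fed at all, which fails. The toy linear model, the NumPy data, the `validation` directory name and `train_with_validation_summaries` are hypothetical stand-ins for the TFRecord pipeline and `m`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in the question


def train_with_validation_summaries(log_dir, max_steps=20, eval_every=5):
    """Toy linear model (hypothetical) that logs train and validation loss."""
    tf.reset_default_graph()
    rng = np.random.RandomState(0)
    x_tr = rng.rand(200, 3).astype(np.float32)
    x_va = rng.rand(50, 3).astype(np.float32)
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_tr, x_tr.sum(axis=1, keepdims=True))).repeat().batch(16)
    valid_ds = tf.data.Dataset.from_tensor_slices(
        (x_va, x_va.sum(axis=1, keepdims=True))).repeat().batch(50)

    # Build every op (including the handle tensors) before the session,
    # because MonitoredTrainingSession finalizes the graph.
    train_handle_t = tf.data.make_one_shot_iterator(train_ds).string_handle()
    valid_handle_t = tf.data.make_one_shot_iterator(valid_ds).string_handle()
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, tf.data.get_output_types(train_ds),
        tf.data.get_output_shapes(train_ds))
    x, y = iterator.get_next()

    w = tf.get_variable("w", shape=[3, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)
    # One summary op: the same "loss" tag lands in both run directories.
    loss_summary = tf.summary.scalar("loss", loss)

    valid_writer = tf.summary.FileWriter(os.path.join(log_dir, "validation"))
    valid_steps = []

    # Training summaries go to log_dir via the session's own SummarySaverHook.
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=log_dir, save_summaries_steps=eval_every) as sess:
        # Raw-session call: the built-in summary hook would otherwise try to
        # evaluate loss_summary here without a value for `handle`.
        train_h, valid_h = sess.run_step_fn(
            lambda step_context: step_context.session.run(
                [train_handle_t, valid_handle_t]))

        def eval_step(step_context):  # run_step_fn requires this name
            # Raw session again: no hooks, so only valid_writer sees this.
            return step_context.session.run(
                [loss_summary, global_step], feed_dict={handle: valid_h})

        for i in range(1, max_steps + 1):
            sess.run(train_op, feed_dict={handle: train_h})
            if i % eval_every == 0:
                summary, step = sess.run_step_fn(eval_step)
                valid_writer.add_summary(summary, step)
                valid_steps.append(int(step))
    valid_writer.close()
    return valid_steps


if __name__ == "__main__":
    print(train_with_validation_summaries(tempfile.mkdtemp()))  # [5, 10, 15, 20]
```

Pointing TensorBoard at `log_dir` should then show the training run (`.`) and the `validation` run side by side under the same `loss` tag.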

I know that tf.train.SummarySaverHook exists, and I guess it is part of the solution, but how would I tell the session which summary saver to use?
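For the `SummarySaverHook` part, one way to hand a hook to the session is the `hooks` argument of `MonitoredTrainingSession`. A minimal sketch, again assuming `tf.compat.v1` and a hypothetical toy model: the built-in summary hook is switched off with `save_summaries_steps=None, save_summaries_secs=None`, and a `SummarySaverHook` with its own `output_dir` and `summary_op` is passed in instead. Note that a hook's ops are fetched inside whichever `sess.run` call triggers it, with that call's `feed_dict`, so with the feedable iterator from the question a "validation" hook built this way would mostly be evaluated on training batches. The `train` directory name and `train_with_custom_summary_hook` are illustrative assumptions.

```python
import os

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in the question


def train_with_custom_summary_hook(log_dir, max_steps=20, every=5):
    """Toy linear model (hypothetical); summaries come from an explicit hook."""
    tf.reset_default_graph()
    x_np = np.random.RandomState(0).rand(100, 3).astype(np.float32)
    ds = tf.data.Dataset.from_tensor_slices(
        (x_np, x_np.sum(axis=1, keepdims=True))).repeat().batch(10)
    x, y = tf.data.make_one_shot_iterator(ds).get_next()

    w = tf.get_variable("w", shape=[3, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)

    # The hook adds summary_op to the fetches of the sess.run calls it
    # triggers on, using those calls' feed_dict.
    hook = tf.train.SummarySaverHook(
        save_steps=every,
        output_dir=os.path.join(log_dir, "train"),  # hypothetical layout
        summary_op=tf.summary.scalar("loss", loss))

    # Both save_summaries_* set to None: no built-in summary hook is added.
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=log_dir,
            save_summaries_steps=None,
            save_summaries_secs=None,
            hooks=[hook]) as sess:
        for _ in range(max_steps):
            sess.run(train_op)
```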

Any help would be greatly appreciated; thanks in advance.

Thomas
