I'm looking at this answer about running evaluation metrics during training:

How to use evaluation_loop with train_loop in tf-slim

It seems that passing a custom step function to `slim.learning.train` (`train_step_fn=train_step_fn`) is the reasonable approach. But I want to run a validation loop, not an evaluation. My graph is something like this:
```python
with tf.Graph().as_default():
    train_dataset = slim.dataset.Dataset(data_sources="train_*.tfrecord")
    train_images, _, train_labels = load_batch(train_dataset,
                                               batch_size=mini_batch_size,
                                               is_training=True)

    val_dataset = slim.dataset.Dataset(data_sources="validation_*.tfrecord")
    val_images, _, val_labels = load_batch(val_dataset,
                                           batch_size=mini_batch_size,
                                           is_training=False)

    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0005)):
        net, end_points = vgg.vgg_16(train_images,
                                     num_classes=10,
                                     is_training=is_training)

    predictions = tf.nn.softmax(net)
    labels = train_labels
    ...

    init_fn = slim.assign_from_checkpoint_fn(
        checkpoint_path,
        slim.get_variables_to_restore(exclude=['vgg_16/fc8']),
        ignore_missing_vars=True
    )

    final_loss = slim.learning.train(train_op, TRAIN_LOG,
                                     train_step_fn=train_step_fn,
                                     init_fn=init_fn,
                                     global_step=global_step,
                                     number_of_steps=steps,
                                     save_summaries_secs=60,
                                     save_interval_secs=600,
                                     session_config=sess_config,
                                     )
```
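As far as I can tell, the error comes from ordering: `slim.learning.train` sets up a `tf.train.Supervisor`, and the Supervisor finalizes the graph before the first call to `train_step_fn`. So every op that validation needs has to be created inside the `with tf.Graph().as_default():` block before `train()` runs, and inside `train_step_fn` you can only `sess.run` ops that already exist. The toy below is plain Python and only mimics that build-then-run rule; `ToyGraph` and `build_graph` are made-up names, not TensorFlow classes.

```python
class ToyGraph:
    """Toy stand-in for a TF 1.x graph: ops may only be added before finalize()."""

    def __init__(self):
        self._ops = {}
        self._finalized = False

    def add_op(self, name, fn):
        if self._finalized:
            # Same message TensorFlow raises when a finalized graph is modified.
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self._ops[name] = fn

    def finalize(self):
        self._finalized = True

    def run(self, name, *inputs):
        # Running an op that already exists is fine after finalize().
        return self._ops[name](*inputs)


def build_graph(variables, lr=0.1):
    """Build BOTH the training op and the validation op before finalizing."""
    g = ToyGraph()

    def train_op(x, y):
        residual = variables["w"] * x - y
        variables["w"] -= lr * 2 * residual * x  # one gradient step on the shared weight
        return residual ** 2                      # loss before the update

    def val_loss(x, y):
        # Reads the same shared weight; never updates it.
        return (variables["w"] * x - y) ** 2

    g.add_op("train_op", train_op)
    g.add_op("val_loss", val_loss)
    return g


if __name__ == "__main__":
    variables = {"w": 0.0}
    g = build_graph(variables)
    g.finalize()                        # roughly what happens before the first training step
    print(g.run("train_op", 1.0, 2.0))  # allowed
    print(g.run("val_loss", 1.0, 2.0))  # allowed, sees the updated weight
    try:
        g.add_op("late_val_loss", lambda x, y: 0.0)  # like building ops inside train_step_fn
    except RuntimeError as err:
        print(err)
```

The design point is that the validation op is declared next to the training op, and only the decision of *when* to run it is deferred to the training loop.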
I want to add something like this, to run a validation step on one mini-batch against the network's current weights:
```python
# train_step is slim's default step function (TF 1.x contrib).
from tensorflow.contrib.slim.python.slim.learning import train_step

def validate_on_checkpoint(sess, *args, **kwargs):
    loss, mean, stddev = sess.run([val_loss, val_rms_mean, val_rms_stddev],
                                  feed_dict={images: val_images,
                                             labels: val_labels,
                                             is_training: is_training})
    validation_writer = tf.train.SummaryWriter(LOG_DIR + '/validation')
    validation_writer.add_summary(loss, global_step)
    validation_writer.add_summary(mean, global_step)
    validation_writer.add_summary(stddev, global_step)

def train_step_fn(sess, *args, **kwargs):
    total_loss, should_stop = train_step(sess, *args, **kwargs)
    if train_step_fn.step % FLAGS.validation_every_n_step == 0:
        validate_on_checkpoint(sess, *args, **kwargs)
    train_step_fn.step += 1
    return [total_loss, should_stop]

train_step_fn.step = 0  # the counter must exist before the first call
```
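The wrapper pattern itself can be checked without TensorFlow. In this sketch the step counter lives in a closure, so it exists before the first call, and the validation callback is only ever *run*, never used to build anything. `make_train_step_fn`, `fake_train_step` and `fake_validate` are hypothetical names. In real slim code the wrapped function would be slim's own `train_step`, which `slim.learning.train` calls as `train_step_fn(sess, train_op, global_step, train_step_kwargs)` and which returns `(total_loss, should_stop)`.

```python
def make_train_step_fn(train_step, validate, every_n):
    """Wrap a slim-style step function so validate(sess) runs every `every_n` steps."""
    if every_n <= 0:
        raise ValueError("every_n must be a positive integer")
    state = {"step": 0}  # counter initialized before the first call

    def train_step_fn(sess, *args, **kwargs):
        total_loss, should_stop = train_step(sess, *args, **kwargs)
        if state["step"] % every_n == 0:
            validate(sess)  # should only run pre-built ops, never create new ones
        state["step"] += 1
        return total_loss, should_stop

    return train_step_fn


if __name__ == "__main__":
    calls = []

    def fake_train_step(sess, *args, **kwargs):
        return 0.5, False  # (total_loss, should_stop)

    def fake_validate(sess):
        calls.append(sess)

    step_fn = make_train_step_fn(fake_train_step, fake_validate, every_n=3)
    for _ in range(7):
        step_fn("fake-session", "train_op", "global_step", {})
    print("validation runs:", len(calls))  # steps 0, 3 and 6
```

A closure is used instead of a function attribute only so that forgetting to initialize the counter is impossible; either works once the counter is set.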
But I get the error `Graph is finalized and cannot be modified.`

Conceptually, I'm not sure how I should add this. The training loop needs gradients, dropout, and weight updates for the net, but the validation loop skips all of that. I keep getting variations on `Graph is finalized and cannot be modified.` if I try to modify the graph, or `XXX is not defined` if I use an `if is_training: ... else: ...` approach.
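The split described above (one set of weights, with dropout and updates only on the training path) can be sketched in plain Python. This is a toy linear model made up for illustration; `TinyModel`, `train_step` and `validate` are hypothetical names, not slim APIs. In TF 1.x the analogous pieces would be ops built before `slim.learning.train` starts, for example a validation tower that reuses the training variables with `is_training=False`; that exact slim wiring is an assumption and is not shown here.

```python
import random


class TinyModel:
    """Toy linear model: prediction = sum(w_i * x_i). Hypothetical, for illustration."""

    def __init__(self, weights, keep_prob=0.5, seed=0):
        self.weights = list(weights)  # shared by the training and validation paths
        self.keep_prob = keep_prob
        self.rng = random.Random(seed)

    def forward(self, x, is_training):
        """Return (prediction, dropout mask). Dropout is active only when training."""
        if is_training:
            # Inverted dropout: keep each input with prob keep_prob and rescale it.
            mask = [1.0 / self.keep_prob if self.rng.random() < self.keep_prob else 0.0
                    for _ in x]
        else:
            mask = [1.0] * len(x)  # validation: no dropout
        pred = sum(w * xi * m for w, xi, m in zip(self.weights, x, mask))
        return pred, mask


def train_step(model, x, y, lr=0.01):
    """Training path: dropout, squared-error gradient, in-place weight update."""
    pred, mask = model.forward(x, is_training=True)
    err = pred - y
    for i, (xi, m) in enumerate(zip(x, mask)):
        model.weights[i] -= lr * 2.0 * err * xi * m
    return err ** 2


def validate(model, x, y):
    """Validation path: same weights, no dropout, no update."""
    pred, _ = model.forward(x, is_training=False)
    return (pred - y) ** 2


if __name__ == "__main__":
    model = TinyModel([0.0, 0.0, 0.0])
    x, y = [1.0, 2.0, 3.0], 1.0
    print("val loss before:", validate(model, x, y))
    for step in range(300):
        train_step(model, x, y)
        if step % 100 == 0:
            print(step, "val loss:", validate(model, x, y))
```

Both paths call the same `forward`; the only thing that differs is the run-time `is_training` value, which plays the role of the switch the question is trying to express with `if is_training: ... else: ...`.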