Following the standard prefetching-queue pattern, I extended the aforementioned example with some additional validation code (see the attached code). That is, every i-th training step, the learned model is evaluated on a validation set (several of them, in my case). The validation sets can't be fed through the queue, so one possible idea is to build an additional inference graph that uses the shared variables.
This somehow works, but after training is finished, the program hangs at coord.join() and eventually throws an exception: Coordinator stopped with threads still running: .... After that, the asynchronous loading thread also throws an exception. The coordinator exception can be handled with a try/except clause (see the code below), but the async thread still raises. Its exception does not hamper the main program, but in my opinion it should not happen at all: the thread has the while loop that should tell it to stop.
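For what it's worth, my reading of the Coordinator docs is that a thread's body can be guarded so that any exception is reported to the coordinator instead of escaping the thread. A minimal sketch of that pattern as I understand it (guarded_load and its arguments are placeholders, not my actual code):

def guarded_load(coord, session, enqueue_op, feed):
    # stop_on_exception() hands any exception to the coordinator
    # (via request_stop()) instead of letting it escape the thread.
    with coord.stop_on_exception():
        while not coord.should_stop():
            session.run(enqueue_op, feed_dict=feed)

Even so, I do not see why the thread should raise in the first place.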
Interestingly, if training is done without any evaluation code running (that is, with the block after if (it+1) % ith == 0: commented out), then coord.join() does not hang at all.
My question: what am I doing wrong here? It seems as if .request_stop() is not doing what I hope it should.
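For completeness, the only additional shutdown step I can think of is closing the queue before joining. A minimal sketch, assuming that close(cancel_pending_enqueues=True) also cancels an enqueue that is currently blocked on a full queue (I have not verified this):

# after the training loop:
coord.request_stop()
# Assumption: cancelling pending enqueues should also unblock the
# loader thread if it is stuck inside session.run(enqueue_op, ...).
sess.run(queue.close(cancel_pending_enqueues=True))
coord.join([t])

The full code: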
import threading

import tensorflow as tf
import numpy as np
# some parameters
btsz = 100 # batch size
some_shape = 20 # size of one input (no of dims)
iters = 1000 # that many single training steps
ith = 10 # run validation sets every so often
# datastores (sort of complex backends, SQL like)
ds_train = ... # the one for training
ds_val1, ds_val2, ds_val3 = ... # having the validation data
def async_load(coord, session, enqueue_op, datastore,
               tf_input, tf_target):
    """
    Feed the queue asynchronously. Inputs can be extracted
    from the datastore only one row at a time.
    """
    while not coord.should_stop():
        inpt = extract_one_input_as_numpy(datastore)
        target = extract_numpy_from(datastore)  # either 0 or 1
        session.run(enqueue_op, feed_dict={tf_input: inpt, tf_target: target})
def evaluate(sess, datastore, tf_input, tf_target, tf_loss, btsz):
"""
Evaluate current model (represented as tf_loss) on a datastore.
"""
loss = []
for i in xrange(something):
        input_batch = collect_btsz_many_single_examples(datastore)
target_batch = same_for_targets(datastore)
tmp, = sess.run([tf_loss], feed_dict={tf_input:input_batch, tf_target:target_batch})
loss.append(tmp)
return np.mean(loss)
def log_reg(input, target, W, b):
"""
Simple logistic regression model.
"""
    # squeeze so the logits have shape [btsz] and match the targets
    y = tf.squeeze(tf.matmul(input, W) + b, [1])
y_bin = tf.to_int32(y > 0)
t_bin = tf.to_int32(target > 0)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y, targets=target))
correct_prediction = tf.equal(y_bin, t_bin)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
return y, loss, accuracy
with tf.Session() as sess:
# Placeholders to represent one input/target pair from a data store.
ds_inpt = tf.placeholder(tf.float32, shape=[some_shape])
ds_trgt = tf.placeholder(tf.float32, shape=[])
    queue = tf.FIFOQueue(capacity=10000, dtypes=[tf.float32, tf.float32],
                         shapes=[[], [some_shape]],
                         shared_name="FIFO", name="FIFO")
# enqueuing, this will be used in the async loading.
enqueue_op = queue.enqueue([ds_trgt, ds_inpt])
# dequeue from queue q, with batch size btsz
q_trgt, q_inpt = queue.dequeue_many(btsz)
    # Parameters for logistic regression
# two functions that build shared variables and initialize these
W = weight_variable([some_shape, 1])
b = bias_variable([1])
# training model, feed from dequeuing the async queue
y, loss, accuracy = log_reg(input=q_inpt, target=q_trgt, W=W, b=b)
train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
# inputs for validation models
val_inpt = tf.placeholder(tf.float32, shape=[btsz, some_shape])
val_trgt = tf.placeholder(tf.float32, shape=[btsz])
# validation model
val_y, val_loss, val_accuracy = log_reg(input=val_inpt, target=val_trgt, W=W, b=b)
sess.run(tf.initialize_all_variables())
try:
coord = tf.train.Coordinator()
# Start a thread to enqueue data asynchronously, and hide I/O latency.
        t = threading.Thread(target=async_load,
                             args=(coord, sess, enqueue_op, ds_train,
                                   ds_inpt, ds_trgt))
t.start()
# collect loss/accuracy for training
# and losses for validation/test sets.
tr_loss = []
tr_acc = []
v_loss = []
for it in xrange(iters):
_, _loss, _acc = sess.run([train_step, loss, accuracy])
tr_loss.append(_loss)
tr_acc.append(_acc)
            if (it+1) % ith == 0:
# run trained model on validation set 1
                tmp = evaluate(sess=sess, datastore=ds_val1,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
v_loss.append(tmp)
# run trained model on validation set 2
                tmp = evaluate(sess=sess, datastore=ds_val2,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
v_loss.append(tmp)
# run trained model on validation set 3
                tmp = evaluate(sess=sess, datastore=ds_val3,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
v_loss.append(tmp)
coord.request_stop()
coord.join([t])
except RuntimeError as rte:
print("Caught {}".format(rte))
# Clear everything!
tf.reset_default_graph()