
Following a standard prefetching-queue example, I extended the aforementioned example with some additional validation code (see the code below). That is, every ith training step, the learned model is evaluated on a validation set (several, in my case). The validation sets can't be fed through the queue, so one possible approach is to build an additional inference graph that uses the shared variables.

This basically works, but after training is finished, the program hangs (at coord.join()) and eventually throws an exception: Coordinator stopped with threads still running:... Then the asynchronous loading thread also throws an exception. The coordinator exception can be handled with a try/except clause (see the code below), but the async thread still throws an exception. That does not hamper the main program, but in my opinion it should not happen: the thread runs a while loop that should tell it to stop.

Interestingly, if training runs without any evaluation code (that is, with the block under if (it+1) % ith == 0: commented out), coord.join() does not hang at all.

My question: What am I doing wrong here? It seems as if .request_stop() does not do what I expect it to.

import threading

import numpy as np
import tensorflow as tf

# some parameters
btsz = 100 # batch size
some_shape = 20 # size of one input (number of dims)
iters = 1000 # that many single training steps
ith = 10 # run validation sets every so often
# datastores (sort of complex backends, SQL like)
ds_train = ... # the one for training
ds_val1, ds_val2, ds_val3 = ... # having the validation data

def async_load(coord, session, queue, datastore,
               tf_input, tf_target):
    """
    Feed queue in async way. Inputs can be extracted
    from datastore only one row at a time.
    """
    while not coord.should_stop():
        input = extract_one_input_as_numpy(datastore)
        target = extract_numpy_from(datastore) # either 0 or 1
        session.run(queue, feed_dict={tf_input: input, tf_target: target})

def evaluate(sess, datastore, tf_input, tf_target, tf_loss, btsz):
    """
    Evaluate current model (represented as tf_loss) on a datastore.
    """
    loss = []
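    for i in range(something):  # `something`: number of validation batches
        # user-defined helpers that read btsz examples/targets from the datastore
        input_batch = collect_btsz_many_single_examples(datastore)
        target_batch = same_for_targets(datastore)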
        tmp, = sess.run([tf_loss], feed_dict={tf_input:input_batch, tf_target:target_batch})
        loss.append(tmp)
    return np.mean(loss)

def log_reg(input, target, W, b):
    """
    Simple logistic regression model.
    """
    # flatten the [batch, 1] logits to [batch] so they match the targets' shape
    y = tf.reshape(tf.matmul(input, W) + b, [-1])
    y_bin = tf.to_int32(y > 0)

    t_bin = tf.to_int32(target > 0)

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y, targets=target))
    correct_prediction = tf.equal(y_bin, t_bin)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    return y, loss, accuracy

with tf.Session() as sess:
    # Placeholders to represent one input/target pair from a data store.
    ds_inpt = tf.placeholder(tf.float32, shape=[some_shape])
    ds_trgt = tf.placeholder(tf.float32, shape=[])

    queue = tf.FIFOQueue(capacity=10000, dtypes=[tf.float32, tf.float32], 
                  shapes=[[], [some_shape]], shared_name="FIFO", name="FIFO")

    # enqueuing, this will be used in the async loading.
    enqueue_op = queue.enqueue([ds_trgt, ds_inpt])

    # dequeue from queue q, with batch size btsz
    q_trgt, q_inpt = queue.dequeue_many(btsz)

    # Parameters for logistic regression
    # two helper functions (not shown) that build shared variables and initialize them
    W = weight_variable([some_shape, 1])
    b = bias_variable([1])

    # training model, feed from dequeuing the async queue
    y, loss, accuracy = log_reg(input=q_inpt, target=q_trgt, W=W, b=b)

    train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    # inputs for validation models
    val_inpt = tf.placeholder(tf.float32, shape=[btsz, some_shape])
    val_trgt = tf.placeholder(tf.float32, shape=[btsz])
    # validation model
    val_y, val_loss, val_accuracy = log_reg(input=val_inpt, target=val_trgt, W=W, b=b)

    sess.run(tf.initialize_all_variables())
    try:
        coord = tf.train.Coordinator()
        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        t = threading.Thread(target=async_load,
                              args=(coord, sess, enqueue_op, ds_train,
                                    ds_inpt, ds_trgt))
        t.start()

        # collect loss/accuracy for training
        # and losses for validation/test sets.
        tr_loss = []
        tr_acc = []
        v_loss = []

        for it in range(iters):
            _, _loss, _acc = sess.run([train_step, loss, accuracy])
            tr_loss.append(_loss)
            tr_acc.append(_acc)
            if (it+1) % ith == 0:
                # run trained model on validation set 1
                tmp = evaluate(sess=sess, datastore=ds_val1,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
                v_loss.append(tmp)
                # run trained model on validation set 2
                tmp = evaluate(sess=sess, datastore=ds_val2,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
                v_loss.append(tmp)
                # run trained model on validation set 3
                tmp = evaluate(sess=sess, datastore=ds_val3,
                               tf_input=val_inpt, tf_target=val_trgt,
                               tf_loss=val_loss, btsz=btsz)
                v_loss.append(tmp)
        coord.request_stop()
        coord.join([t])
    except RuntimeError as rte:
        print("Caught {}".format(rte))
# Clear everything!
tf.reset_default_graph()
osdf

1 Answer


There is a race condition in your code. The thread running async_load() will block forever if the following events happen in this order:

  1. async_load() calls coord.should_stop() which returns False.
  2. async_load() calls session.run(queue, ...) but the queue is full so the call blocks indefinitely.
  3. Main thread calls coord.request_stop().
  4. Main thread calls coord.join([t]), and this blocks forever because of (2).
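The four steps above can be reproduced without TensorFlow. The following is a minimal stdlib sketch, assuming that a `threading.Event` is a fair stand-in for `tf.train.Coordinator` (`should_stop()`/`request_stop()`) and that a bounded `queue.Queue` is a fair stand-in for the bounded `FIFOQueue`; `demonstrate_race` is a hypothetical name. Nothing dequeues while the producer runs, mimicking a consumer that has stopped pulling batches, so the producer passes its stop check and then blocks in `put()`.

```python
import queue
import threading
import time


def demonstrate_race(capacity=2, join_timeout=0.2):
    """Hypothetical demo of the four-step race using stdlib stand-ins."""
    stop = threading.Event()             # stand-in for tf.train.Coordinator
    buf = queue.Queue(maxsize=capacity)  # stand-in for the bounded FIFOQueue
    attempts = 0

    def producer():                      # stand-in for async_load()
        nonlocal attempts
        while not stop.is_set():         # 1. should_stop() returns False ...
            attempts += 1
            buf.put(attempts)            # 2. ... then put() blocks on a full queue

    t = threading.Thread(target=producer, daemon=True)
    t.start()
    # Nothing dequeues, so wait until the producer has started a put() that
    # cannot succeed (capacity items stored, one more attempted).
    while attempts <= capacity:
        time.sleep(0.01)

    stop.set()                           # 3. request_stop()
    t.join(timeout=join_timeout)         # 4. join() cannot finish: thread is stuck
    hung = t.is_alive()

    buf.get()                            # free one slot so the blocked put() completes
    t.join(timeout=1.0)                  # the loop now sees the stop flag and exits
    return hung, t.is_alive()
```

`demonstrate_race()` should return `(True, False)`: the first join times out while `put()` is blocked, and only freeing a slot lets the thread re-check the stop flag and exit.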

One way to avoid this is to create a queue.close(cancel_pending_enqueues=True) op, and run it in the main thread before calling coord.request_stop(). This will unblock the async_load() thread, and enable coord.join([t]) to return.
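The close-then-stop fix can be sketched the same way. `ClosableQueue` and `QueueClosed` below are hypothetical stand-ins, not TensorFlow classes: `close()` plays the role of running `queue.close(cancel_pending_enqueues=True)` by waking every blocked `put()` and making it fail, and `QueueClosed` stands in for the error the loading thread then sees. The sketch only models the cancel-pending-enqueues behaviour; it is an analogy for the mechanism, not TensorFlow's implementation.

```python
import threading
import time
from collections import deque


class QueueClosed(Exception):
    """Hypothetical stand-in for the error a cancelled enqueue raises."""


class ClosableQueue:
    """Bounded FIFO whose close() wakes and fails every blocked put()."""

    def __init__(self, capacity):
        self._items = deque()
        self._capacity = capacity
        self._closed = False
        self._cond = threading.Condition()

    def put(self, item):
        with self._cond:
            # Block while full, but give up as soon as the queue is closed.
            while len(self._items) >= self._capacity and not self._closed:
                self._cond.wait()
            if self._closed:
                raise QueueClosed("queue is closed")
            self._items.append(item)
            self._cond.notify_all()  # wake any waiting get()

    def get(self):
        with self._cond:
            while not self._items and not self._closed:
                self._cond.wait()
            if not self._items:
                raise QueueClosed("queue is closed and empty")
            item = self._items.popleft()
            self._cond.notify_all()  # a slot is free again
            return item

    def close(self):
        with self._cond:
            self._closed = True
            self._cond.notify_all()  # wake blocked put()/get() so they can fail


def run_with_clean_shutdown(capacity=2, timeout=1.0):
    stop = threading.Event()         # stand-in for tf.train.Coordinator
    buf = ClosableQueue(capacity)
    attempts = 0
    outcome = {"cancelled": False}

    def producer():                  # stand-in for async_load()
        nonlocal attempts
        try:
            while not stop.is_set():
                attempts += 1
                buf.put(attempts)
        except QueueClosed:
            outcome["cancelled"] = True  # expected during shutdown; exit quietly

    t = threading.Thread(target=producer, daemon=True)
    t.start()
    while attempts <= capacity:      # wait until the producer is stuck on a full queue
        time.sleep(0.01)

    buf.close()                      # 1. close and cancel pending enqueues
    stop.set()                       # 2. then request_stop()
    t.join(timeout=timeout)          # 3. join now returns promptly
    return not t.is_alive(), outcome["cancelled"]
```

`run_with_clean_shutdown()` should return `(True, True)`: the thread has exited, and it left through the `QueueClosed` handler. In the TensorFlow code the analogous steps would be to build `close_op = queue.close(cancel_pending_enqueues=True)` once, run `sess.run(close_op)` before `coord.request_stop()`, and catch the resulting error (`tf.errors.CancelledError`, per the comments) inside `async_load()`.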

mrry
  • This results in a `CancelledError` and an `AbortedError` in the loading thread, which is supposed to happen and simply needs to be caught, correct? – osdf Mar 25 '16 at 10:11
  • Not strictly related to the original question, but tied to the code: could `q_trgt` and `q_inpt`, the dequeued tensors, be fed directly in the evaluation part? If they are used in a feed dictionary, the queue would not be polled, correct? This would avoid the additional `val_trgt` and `val_inpt` placeholders. – osdf Mar 25 '16 at 10:16
  • Thank you for the swift answer!! :) – osdf Mar 25 '16 at 10:16