I have seen a few posts on restoring TF models, as well as the Google doc page on exporting graphs, but I think I am missing something.

I use the code in this Gist to save the model, along with this utils file, which defines the model.

Now I would like to restore it and run it on previously unseen test data, as follows:

def evaluate(X_data, y_data):
    num_examples = len(X_data)
    total_accuracy = 0
    total_loss = 0
    sess = tf.get_default_session()
    acc_steps = len(X_data) // BATCH_SIZE
    for i in range(acc_steps):
        # draw batches from the data passed in, not from the validation set
        batch_x, batch_y = next_batch(X_data, y_data, BATCH_SIZE)

        loss, accuracy = sess.run([loss_value, acc], feed_dict={
                images_placeholder: batch_x,
                labels_placeholder: batch_y,
                keep_prob: 0.5
                })
        total_accuracy += (accuracy * len(batch_x))
        total_loss += (loss * len(batch_x))
    return (total_accuracy / num_examples, total_loss / num_examples)

## re-execute the code that defines the model

# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')

gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')

gray /= 255.

# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')

# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')

# construct model
logits = inference(gray, keep_prob)

# calculate loss
loss_value = loss(logits, labels_placeholder)

# training
train_op = training(loss_value, 0.001)

# accuracy
acc = accuracy(logits, labels_placeholder)

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('gtsd.meta')
    loader.restore(sess, tf.train.latest_checkpoint('./'))
    sess.run(tf.initialize_all_variables())   
    test_accuracy = evaluate(X_test, y_test)
    print("Test Accuracy = {:.3f}".format(test_accuracy[0]))

I'm getting a test accuracy of only 3%. However, if I don't close the Notebook and run the test code immediately after training the model, I get 95% accuracy.

This leads me to believe I'm not loading the model correctly?

Sam Hammamy

3 Answers

The problem arises from these two lines:

loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())   

The first line loads the saved model from a checkpoint. The second line re-initializes all of the variables in the model (such as the weight matrices, convolutional filters, and bias vectors), usually to random numbers, and overwrites the loaded values.

The solution is simple: delete the second line (sess.run(tf.initialize_all_variables())) and evaluation will proceed with the trained values loaded from the checkpoint.


PS. There is a small chance that this change will give you an error about "uninitialized variables". In that case, you should execute sess.run(tf.initialize_all_variables()) to initialize any variables not saved in the checkpoint before executing loader.restore(sess, tf.train.latest_checkpoint('./')).
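
To make the ordering issue concrete, here is a minimal sketch on a toy graph. It is not the asker's model: the variable `w`, the helper `build_graph`, and the temporary checkpoint path are hypothetical, and it assumes a TensorFlow install that provides the TF1-style API through `tensorflow.compat.v1` (where `tf.global_variables_initializer()` is the newer name for `tf.initialize_all_variables()`).

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions


def build_graph():
    """Build a fresh graph holding one randomly initialized variable."""
    graph = tf.Graph()
    with graph.as_default():
        w = tf.get_variable("w", shape=[3],
                            initializer=tf.random_normal_initializer())
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    return graph, w, init, saver


# hypothetical checkpoint location in a temporary directory
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")

# Stand-in for training: give w known values and save them.
graph, w, init, saver = build_graph()
with tf.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(w.assign([1.0, 2.0, 3.0]))
    saver.save(sess, ckpt_path)

# Restore, then initialize: the initializer overwrites the restored values.
graph, w, init, saver = build_graph()
with tf.Session(graph=graph) as sess:
    saver.restore(sess, ckpt_path)
    sess.run(init)
    wrong_order = sess.run(w)

# Initialize, then restore: the restored values are the ones that remain.
graph, w, init, saver = build_graph()
with tf.Session(graph=graph) as sess:
    sess.run(init)
    saver.restore(sess, ckpt_path)
    right_order = sess.run(w)
```

In this sketch `right_order` holds the saved values 1, 2, 3, while `wrong_order` holds fresh random draws, which mirrors the drop to chance-level accuracy described in the question.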

mrry
  • Thanks @mrry I will try this now – Sam Hammamy Dec 22 '16 at 16:56
  • As you expected, TF throws an error about an uninitialized variable. When I move that line up as you suggested, it still only gives 2% accuracy, so it's starting from the beginning. – Sam Hammamy Dec 22 '16 at 16:59
  • Oh, I noticed another problem! `tf.train.import_meta_graph()` will load a **second copy** of the model structure into the current graph. If the code before you created the `tf.Session` builds a copy of the graph (including all of the weights), *those* weights will remain uninitialized, and only the weights in the second copy will be restored. There are two ways to deal with this: (1) Instead of using `tf.train.import_meta_graph()`, create a `tf.train.Saver` directly and use it to restore the checkpoint into the initial copy of the graph; or... – mrry Dec 22 '16 at 17:14
  • (2) Avoid building the evaluation graph before using `tf.train.import_meta_graph()` and instead use introspection methods such as `tf.get_default_graph().get_operation_by_name()` to look up the loss, accuracy, and placeholder tensors in the original graph. Both approaches might require some restructuring (essentially you have to make sure that the variable names are the same in both the graph and the checkpoint) but I expect option (1) will involve less work. – mrry Dec 22 '16 at 17:17
  • I tried option (1) but it did not solve it. I commented out sess = tf.Session() in ## re-execute the code that defines the model but still no luck. I will try option 2 now – Sam Hammamy Dec 22 '16 at 17:34
  • @mrry how about the dropout? How is one supposed to reset that to `1.0` at eval time? Will declaring a new `tf.placeholder()` simply work, or should one restore the placeholder from the training? – Nicolai Anton Lynnerup Mar 20 '17 at 19:48
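
Option (2) from the comments, together with the dropout question, can be sketched on a toy graph as follows. This is an illustration under assumptions rather than the asker's model: the op name `out`, the variable `w`, and the checkpoint path are hypothetical, the placeholder names `x` and `drop` are borrowed from the question, and the TF1-style API is assumed to be available through `tensorflow.compat.v1`. The evaluation graph is not rebuilt in code; the placeholders are looked up by name after `import_meta_graph()`, and the restored dropout placeholder is simply fed `1.0`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions

# hypothetical checkpoint location in a temporary directory
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")

# Stand-in for the training script: name every tensor needed later.
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    keep_prob = tf.placeholder(tf.float32, name="drop")
    w = tf.get_variable("w", initializer=tf.constant([[1.0], [2.0]]))
    out = tf.matmul(tf.nn.dropout(x, keep_prob), w, name="out")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)  # also writes ckpt_path + ".meta"

# Evaluation: start from an empty graph so only one copy of the model exists.
eval_graph = tf.Graph()
with eval_graph.as_default():
    loader = tf.train.import_meta_graph(ckpt_path + ".meta")
    # Look tensors up by name; ":0" selects the first output of each op.
    x_t = eval_graph.get_tensor_by_name("x:0")
    drop_t = eval_graph.get_tensor_by_name("drop:0")
    out_t = eval_graph.get_tensor_by_name("out:0")
    with tf.Session() as sess:
        loader.restore(sess, ckpt_path)
        # Feeding keep_prob = 1.0 turns dropout into a no-op at eval time.
        result = sess.run(out_t, feed_dict={x_t: [[1.0, 1.0]], drop_t: 1.0})
```

Because nothing is dropped, `result` is just the input times the restored weights. No initializer is run in the evaluation session, since every variable in the imported graph is covered by the checkpoint.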

I had a similar problem and for me this worked:

with tf.Session() as sess:
    saver=tf.train.Saver(tf.all_variables())
    saver=tf.train.import_meta_graph('model.meta')
    saver.restore(sess,"model")

    test_accuracy = evaluate(X_test, y_test)
somberlain

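The answer found here is what ended up working. Save the checkpoint after training, then later rebuild the model and restore into it with a plain tf.train.Saver: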

save_path = saver.save(sess, '/home/ubuntu/gtsd-12-23-16.chkpt')
print("Model saved in file: %s" % save_path)
## later re-run code that creates the model
# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')

gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')

gray /= 255.

# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')

# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')

# construct model
logits = inference(gray, keep_prob)

# calculate loss
loss_value = loss(logits, labels_placeholder)

# training
train_op = training(loss_value, 0.001)

# accuracy
acc = accuracy(logits, labels_placeholder)

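# create the Saver after the model is built so it covers all model variables
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore assigns the saved values, so no initializer is run afterwards
    saver.restore(sess, '/home/ubuntu/gtsd-12-23-16.chkpt')
    print("Model restored.")
    test_accuracy = evaluate(X_test, y_test)
    print("Test Accuracy = {:.3f}".format(test_accuracy[0]*100))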
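
Restoring into a rebuilt graph like this only works when the variable names in the graph match those stored in the checkpoint, as noted in the comments on the first answer. One way to check this, sketched here on a toy model, is to compare the two sets of names with `tf.train.list_variables`. The helper `build_model`, the scope and variable names, and the checkpoint path are hypothetical, and the TF1-style API is assumed to be available through `tensorflow.compat.v1`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: TF1-style graphs and sessions

# hypothetical checkpoint location in a temporary directory
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")


def build_model():
    """Hypothetical stand-in for the code that defines the model."""
    with tf.variable_scope("conv1"):
        tf.get_variable("weights", shape=[5, 5, 1, 6])
        tf.get_variable("biases", shape=[6])


# Stand-in for training: build the model once and save a checkpoint.
save_graph = tf.Graph()
with save_graph.as_default():
    build_model()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Names stored in the checkpoint, read without building any graph.
names_in_checkpoint = {name for name, _shape in tf.train.list_variables(ckpt_path)}

# Names of the variables in the rebuilt graph.
restore_graph = tf.Graph()
with restore_graph.as_default():
    build_model()
    names_in_graph = {v.op.name for v in tf.global_variables()}

# Any name listed here has no saved value for Saver.restore() to load.
missing = names_in_graph - names_in_checkpoint
```

An empty `missing` set means every variable in the rebuilt graph has a saved value. A non-empty set usually points to the model being built twice in the same graph or under a different variable scope.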
Sam Hammamy