
How can I export the Estimator from the MNIST tutorial on TensorFlow's site and then import it again for prediction? Thank you!

Razvan

1 Answer


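The Estimator takes a model_dir argument, and that is where its checkpoints are saved during training. To predict later, construct an Estimator with the same model_fn and model_dir and call its predict method: it rebuilds the graph and restores the weights from the latest checkpoint in model_dir, so for prediction in Python no separate export/import step is needed.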

For the MNIST example, the prediction code would be:

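import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.estimator.inputs, tf.reset_default_graph)

# eval_data and cnn_model_fn are defined in the MNIST tutorial.
tf.reset_default_graph()

# An input function that feeds the data to predict on, once and in order.
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": eval_data},
    num_epochs=1,
    shuffle=False)

# Rebuild the Estimator with the same model_fn and model_dir used for training;
# predict() restores the latest checkpoint found in model_dir.
mnist_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

# Prediction call: returns a generator that yields one dict per example.
predictions = mnist_classifier.predict(input_fn=predict_input_fn)

pred_class = np.array([p['classes'] for p in predictions]).squeeze()
print(pred_class)

# Output
# [7 2 1 ... 4 5 6]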
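The predictions object returned by predict is a generator, so it can be iterated only once; if you need more than one field (for example both classes and probabilities), it is safer to collect it into a list first. The sketch below uses a hypothetical `fake_predictions` generator as a stand-in for `mnist_classifier.predict(...)` so it runs without TensorFlow or a trained checkpoint; the values it yields are made up for illustration.

```python
import numpy as np


def fake_predictions():
    """Hypothetical stand-in for mnist_classifier.predict(...).

    Like Estimator.predict, it is a generator that yields one dict per
    input example; the numbers here are invented for illustration.
    """
    for cls in [7, 2, 1]:
        probs = np.full(10, 0.01, dtype=np.float32)
        probs[cls] = 0.91  # 9 * 0.01 + 0.91 == 1.0
        yield {"classes": np.int64(cls), "probabilities": probs}


predictions = fake_predictions()

# Materialize the generator once, then pull out as many fields as needed.
results = list(predictions)

pred_class = np.array([p["classes"] for p in results])
pred_prob = np.stack([p["probabilities"] for p in results])

print(pred_class)       # [7 2 1]
print(pred_prob.shape)  # (3, 10)

# The generator is now exhausted: iterating it again yields nothing.
print(len(list(predictions)))  # 0
```

If you iterate the original generator directly in two separate list comprehensions, the second one comes back empty, which is easy to miss.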
Vijay Mariappan
  • How do I know which keys the prediction dicts contain? For example, if I want the probabilities instead of the classes, can I check whether a probabilities key is present? – hoang tran Nov 09 '18 at 16:50
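On the question about keys: each dict yielded by predict contains whatever the model_fn puts in the `predictions` dict of its `EstimatorSpec`; in the tutorial's `cnn_model_fn` those are `"classes"` and `"probabilities"`. With the real Estimator you can also ask for a subset via `mnist_classifier.predict(input_fn=predict_input_fn, predict_keys=["probabilities"])`. A minimal sketch of inspecting the keys, using a hypothetical `fake_predictions` stand-in with made-up one-hot probabilities so it runs without TensorFlow:

```python
import numpy as np


def fake_predictions():
    # Hypothetical stand-in for mnist_classifier.predict(...); the real keys
    # come from the `predictions` dict built inside cnn_model_fn.
    yield {"classes": np.int64(3), "probabilities": np.eye(10, dtype=np.float32)[3]}
    yield {"classes": np.int64(8), "probabilities": np.eye(10, dtype=np.float32)[8]}


results = list(fake_predictions())

# The keys of any yielded dict show what the model_fn exposes.
available = sorted(results[0].keys())
print(available)  # ['classes', 'probabilities']

if "probabilities" in results[0]:
    probs = np.stack([p["probabilities"] for p in results])
    # The most likely class in each row should agree with "classes".
    print(probs.argmax(axis=1))  # [3 8]
```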