How do I save and load a DNN classifier in TensorFlow? I'm asking about the default Iris classifier program given in the tutorial (https://www.tensorflow.org/get_started/estimator).
Here is a detailed example with the latest TensorFlow version (1.7): https://stackoverflow.com/a/52222383/5904928 – Aaditya Ura Sep 07 '18 at 12:22
2 Answers
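To save and reuse the classifier you don't need any extra saving code: the Estimator writes checkpoints to model_dir during training, so you can simply reload it with the same model_dir path.
For example, in the method where you want to use the classifier, create the classifier again with the same model_dir (and the same feature_columns, hidden_units and n_classes). It restores its variables from the latest checkpoint in that directory, i.e. from whatever state it was in previously.
I use this for training and then reload it for testing single examples.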
import tensorflow as tf

# feature_columns is defined earlier in the Iris program
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 20, 10],
    n_classes=3,
    model_dir="/tmp/iris_model")
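A minimal end-to-end sketch of this reload pattern, assuming a TensorFlow release that still ships tf.estimator (1.x, or 2.x up to 2.15). The Iris CSV loading is replaced by hypothetical random data, and make_classifier, train_input_fn and predict_input_fn are illustrative helper names, not part of the tutorial. A temporary directory stands in for /tmp/iris_model so reruns start from scratch.

```python
import tempfile

import numpy as np
import tensorflow as tf  # assumption: a TF version that still ships tf.estimator

# Hypothetical stand-in for the Iris data: 4 float features, 3 classes.
rng = np.random.RandomState(0)
FEATURES = rng.rand(120, 4).astype(np.float32)
LABELS = rng.randint(0, 3, size=120).astype(np.int64)

feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]


def make_classifier(model_dir):
    # Same arguments + same model_dir => the estimator resumes from the checkpoints there.
    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3,
        model_dir=model_dir)


def train_input_fn():
    ds = tf.data.Dataset.from_tensor_slices(({"x": FEATURES}, LABELS))
    return ds.shuffle(120).repeat().batch(32)


def predict_input_fn():
    # Features only; predict() does not need labels.
    return tf.data.Dataset.from_tensor_slices({"x": FEATURES[:5]}).batch(5)


model_dir = tempfile.mkdtemp()  # hypothetical replacement for "/tmp/iris_model"

# Training: checkpoints are written to model_dir automatically.
make_classifier(model_dir).train(input_fn=train_input_fn, steps=50)

# Later (e.g. in a separate script): rebuild with the same model_dir and use it.
restored = make_classifier(model_dir)
predictions = list(restored.predict(input_fn=predict_input_fn))
print(restored.get_variable_value("global_step"))  # step count read from the checkpoint
print([int(p["class_ids"][0]) for p in predictions])
```

Calling train again on restored continues from the saved global step instead of starting over, which is the same restore mechanism at work.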

Save
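The first thing you need to do is create a tf.train.Saver object once the graph (and its variables) has been built, for example inside your session:
import tensorflow as tf

with tf.Session(graph=graph) as sess:  # graph: the tf.Graph that holds your model
    saver = tf.train.Saver()
Then, after training (and still inside the session), call the save method:
    saver.save(sess, 'path/to/model_file')
You don't need to specify a file extension: save uses 'path/to/model_file' as a prefix and writes model_file.meta (the graph), model_file.index and model_file.data-00000-of-00001 (the variable values), plus a small checkpoint file that records the latest save in that directory.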
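A runnable sketch of the save step, assuming TF 1.15 or 2.x through the tensorflow.compat.v1 module (on older 1.x releases, import tensorflow as tf directly and drop disable_eager_execution). The toy graph, its tensor names (x, output) and the variable w are hypothetical; the script prints the checkpoint prefix and the files written next to it.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.15 or 2.x

tf.disable_eager_execution()  # Saver and Session work in graph mode


def build_and_save(save_dir):
    """Build a toy graph, initialize it, and write a checkpoint with prefix save_dir/model_file."""
    graph = tf.Graph()
    with graph.as_default():
        # Tensor names "x" and "output" are illustrative; the Load step looks them up by name.
        x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
        w = tf.Variable(tf.ones([4, 3]), name="w")
        tf.matmul(x, w, name="output")
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()  # created after the variables it should save

    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # ... a real training loop would run here ...
        # No extension: save() returns the prefix it used for the checkpoint files.
        return saver.save(sess, os.path.join(save_dir, "model_file"))


save_dir = tempfile.mkdtemp()
prefix = build_and_save(save_dir)
print(prefix)
print(sorted(os.listdir(save_dir)))
```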
Load
To restore the model, open a new session and do the following. You don't need to rebuild the graph in code, because import_meta_graph recreates it from the .meta file:
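with tf.Session() as sess:
    # import_meta_graph rebuilds the saved graph from the .meta file
    new_saver = tf.train.import_meta_graph('path/to/model_file.meta')
    # latest_checkpoint takes the directory that contains model_file
    new_saver.restore(sess, tf.train.latest_checkpoint('path/to/'))
    # restore the tensors you want (usually the ones you use in feed_dict and sess.run);
    # "x" and "output" must be the names you gave them when building the graph
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    output = graph.get_tensor_by_name("output:0")
    # input_data is your own input array (placeholder name): feed it to x, not x itself
    feed_dict = {x: input_data}
    [result] = sess.run([output], feed_dict=feed_dict)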
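A self-contained round trip under the same kind of assumptions (TF 1.15 or 2.x via tensorflow.compat.v1, and a hypothetical toy graph with tensors named x and output). It saves a checkpoint, restores it into a fresh graph using only the .meta file and the checkpoint, and feeds an actual NumPy array to x. The fake_train_step op is a hypothetical stand-in for real training.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15 or 2.x

tf.disable_eager_execution()  # Saver and Session work in graph mode


def save_model(model_dir):
    """Build a toy graph, 'train' it, and checkpoint it as model_dir/model_file."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
        w = tf.Variable(tf.ones([4, 3]), name="w")
        tf.matmul(x, w, name="output")
        fake_train_step = w.assign(w * 2.0)  # hypothetical stand-in for a training op
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    with tf.Session(graph=graph) as sess:
        sess.run(init)             # w starts at 1.0 everywhere
        sess.run(fake_train_step)  # w is now 2.0 everywhere
        saver.save(sess, os.path.join(model_dir, "model_file"))


def load_and_predict(model_dir, input_data):
    """Recreate the graph from the .meta file, restore w, and evaluate "output"."""
    with tf.Session(graph=tf.Graph()) as sess:
        meta_path = os.path.join(model_dir, "model_file.meta")
        new_saver = tf.train.import_meta_graph(meta_path)
        new_saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        graph = sess.graph
        x = graph.get_tensor_by_name("x:0")
        output = graph.get_tensor_by_name("output:0")
        # Feed real data to the placeholder (not {x: x}).
        return sess.run(output, feed_dict={x: input_data})


model_dir = tempfile.mkdtemp()
save_model(model_dir)
print(load_and_predict(model_dir, np.ones((2, 4), dtype=np.float32)))
```

Because the saved w is 2.0 everywhere, an all-ones input gives 4 × 2.0 = 8.0 in every output entry; a freshly initialized w would give 4.0, so the value confirms the checkpointed weights were restored.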
You can also check this tutorial about saving and restoring tensorflow models. I hope it helps!

This is the program. Just tell me how to persist the model that I create using: classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model") https://www.tensorflow.org/get_started/estimator#tf-contrib-learn-quickstart – Abhisek Roy Jan 04 '18 at 05:48