I want to save a TensorFlow (0.12.0) model, including the graph and variable values, then later load and execute it. I have read the docs and other posts on this but cannot get the basics to work. I am using the technique from this page in the TensorFlow docs. Code:
Save a simple model:
    import tensorflow as tf

    myVar = tf.Variable(7.1)
    tf.add_to_collection('modelVariables', myVar) # why?
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print sess.run(myVar)
        saver0 = tf.train.Saver()
        saver0.save(sess, './myModel.ckpt')
        saver0.export_meta_graph('./myModel.meta')
Later, load and execute the model:
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph('./myModel.meta')
        saver1.restore(sess, './myModel.meta')
        print sess.run(myVar)
Question 1: The saving code seems to work but the loading code produces this error:
W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open ./myModel.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
How can I fix this?
Question 2: I included this line to follow the pattern in the TF docs...
    tf.add_to_collection('modelVariables', myVar)
... but why is that line necessary? Doesn't export_meta_graph export the entire graph by default? If not, does one need to add every variable in the graph to the collection before saving? Or do we only add to the collection those variables that will be accessed after the restore?
---------------------- Update January 12 2017 -----------------------------
Partial success based on Kashyap's suggestion below, but a mystery remains. The code below works, but only if I include the lines containing tf.add_to_collection and tf.get_collection. Without those lines, 'load' mode throws an error on the last line:
    NameError: name 'myVar' is not defined
My understanding was that by default Saver.save and Saver.restore save and restore all variables in the graph, so why is it necessary to specify the name of variables that will be used in the collection? I assume this has to do with mapping TensorFlow's variable names to Python names, but what are the rules of the game here? For which variables does this need to be done?
    import tensorflow as tf

    mode = 'load' # or 'save'

    if mode == 'save':
        myVar = tf.Variable(7.1)
        init_op = tf.global_variables_initializer()
        saver0 = tf.train.Saver()
        tf.add_to_collection('myVar', myVar) ### WHY NECESSARY?
        with tf.Session() as sess:
            sess.run(init_op)
            print sess.run(myVar)
            saver0.save(sess, './myModel')

    if mode == 'load':
        with tf.Session() as sess:
            saver1 = tf.train.import_meta_graph('./myModel.meta')
            saver1.restore(sess, tf.train.latest_checkpoint('./'))
            myVar = tf.get_collection('myVar')[0] ### WHY NECESSARY?
            print sess.run(myVar)
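
To make the question about Python names versus graph names concrete, here is a minimal sketch of the same save/load round trip that looks the variable up by its graph name instead of through a custom collection. It is only an illustration under stated assumptions, not necessarily the recommended pattern: the explicit name='myVar' argument, the temporary directory, the separate tf.Graph objects and the import shim (so the snippet can also run on newer TensorFlow releases) are not part of my original code.

```python
import os
import tempfile

try:
    # Assumption: newer TensorFlow releases expose the graph-mode API here.
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
except (ImportError, AttributeError):
    # TensorFlow 0.12 / early 1.x: the graph-mode API is the top-level module.
    import tensorflow as tf

# Hypothetical location for the checkpoint files.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'myModel')

# Save: the variable gets an explicit graph name ('myVar'), which is what
# ends up in the graph and checkpoint. The Python identifier is not stored.
save_graph = tf.Graph()
with save_graph.as_default():
    some_python_name = tf.Variable(7.1, name='myVar')
    init_op = tf.global_variables_initializer()
    saver0 = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        saver0.save(sess, ckpt_prefix)  # also writes ckpt_prefix + '.meta'

# Load into a fresh graph, where no Python reference to the variable exists.
load_graph = tf.Graph()
with load_graph.as_default():
    with tf.Session() as sess:
        saver1 = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        # Restore from the checkpoint prefix, not from the .meta file.
        saver1.restore(sess, ckpt_prefix)
        # Variables are in the default 'variables' collection, which is
        # recreated by import_meta_graph, so they can be found by graph name.
        restored = [v for v in tf.global_variables() if v.name == 'myVar:0'][0]
        restored_value = sess.run(restored)
        print(restored_value)
```

In this sketch the Python identifier used when saving (some_python_name) never appears on the loading side; only the graph name 'myVar:0' does. Without name='myVar', the variable would get an auto-generated graph name, which makes this kind of lookup less convenient.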