I found the following code snippet to visualize a model that was saved to a `*.pb` file:
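```python
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
model_filename = 'saved_model.pb'
with tf.Session() as sess:
    # Parse the SavedModel protocol buffer stored in the .pb file
    with gfile.FastGFile(model_filename, 'rb') as f:
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        # Import the first MetaGraph's GraphDef into the session's graph
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
    # Write the graph as an event file that TensorBoard can display
    LOGDIR = '.'
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
    train_writer.close()
```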
Now I am struggling to create `saved_model.pb` in the first place. If my `sess.run` call looks like this:
```python
_, cr_loss = sess.run([train_op, cross_entropy_loss],
                      feed_dict={input_image: images,
                                 correct_label: gt_images,
                                 keep_prob: KEEP_PROB,
                                 learning_rate: LEARNING_RATE})
```
How do I save the graph that contains `train_op` to `saved_model.pb`?
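For reference, here is a minimal sketch of one way to produce such a file, assuming TensorFlow's 1.x graph/session API (accessed through `tf.compat.v1`, which also exists in TF 2.x). A SavedModel stores the entire graph the session runs, so `train_op` and the ops around it end up in `saved_model.pb` without being extracted separately. The toy model, its shapes, the signature names, and `export_dir` are hypothetical stand-ins; only the placeholder names are taken from the `feed_dict` above. The `graph_node_names` helper reads the file back with the same `saved_model_pb2.SavedModel().ParseFromString` approach as the visualization snippet.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2

tf1 = tf.compat.v1  # TF 1.x-style graph/session API (also present in TF 2.x)


def build_train_and_save(export_dir):
    """Build a toy graph, run one training step, and export a SavedModel.

    The model is a hypothetical stand-in; only the placeholder names follow
    the feed_dict in the question.
    """
    graph = tf1.Graph()
    with graph.as_default():
        input_image = tf1.placeholder(tf.float32, [None, 4], name="input_image")
        correct_label = tf1.placeholder(tf.float32, [None, 2], name="correct_label")
        keep_prob = tf1.placeholder(tf.float32, name="keep_prob")
        learning_rate = tf1.placeholder(tf.float32, name="learning_rate")

        weights = tf1.get_variable("weights", [4, 2])
        dropped = tf.nn.dropout(input_image, rate=1.0 - keep_prob)
        logits = tf.matmul(dropped, weights, name="logits")
        cross_entropy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=correct_label, logits=logits))
        train_op = tf1.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy_loss)

        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Random dummy batch standing in for the real images and labels
            images = np.random.rand(8, 4).astype(np.float32)
            gt_images = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, size=8)]
            _, cr_loss = sess.run([train_op, cross_entropy_loss],
                                  feed_dict={input_image: images,
                                             correct_label: gt_images,
                                             keep_prob: 0.5,
                                             learning_rate: 0.01})

            # Export the session's whole graph (train_op included) plus the
            # current variable values; builder.save() writes saved_model.pb
            builder = tf1.saved_model.Builder(export_dir)
            signature = tf1.saved_model.predict_signature_def(
                inputs={"input_image": input_image, "keep_prob": keep_prob},
                outputs={"logits": logits})
            builder.add_meta_graph_and_variables(
                sess, [tf1.saved_model.tag_constants.SERVING],
                signature_def_map={"serving_default": signature})
            builder.save()
    return cr_loss


def graph_node_names(pb_path):
    """Parse saved_model.pb the same way as the visualization snippet."""
    sm = saved_model_pb2.SavedModel()
    with open(pb_path, "rb") as f:
        sm.ParseFromString(f.read())
    return {node.name for node in sm.meta_graphs[0].graph_def.node}


if __name__ == "__main__":
    export_dir = os.path.join(tempfile.mkdtemp(), "export")  # must not exist yet
    build_train_and_save(export_dir)
    names = graph_node_names(os.path.join(export_dir, "saved_model.pb"))
    print(len(names), "nodes; logits present:", "logits" in names)
```

The resulting `saved_model.pb` path can then be used as `model_filename` in the visualization snippet. `export_dir` should point to a directory that does not exist yet, since the builder refuses to overwrite an existing export. In a real model, the tensors passed to `predict_signature_def` would be the network's actual inputs and outputs.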