
I found the following code snippet to visualize a model which was saved to a *.pb file:

import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
from tensorflow.core.protobuf import saved_model_pb2

model_filename = 'saved_model.pb'
with tf.Session() as sess:
    with gfile.FastGFile(model_filename, 'rb') as f:
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
        LOGDIR = '.'
        train_writer = tf.summary.FileWriter(LOGDIR)
        train_writer.add_graph(sess.graph)

Now I am struggling to create the saved_model.pb in the first place. If my session.run looks like this:

  _, cr_loss = sess.run([train_op,cross_entropy_loss],
                         feed_dict={input_image: images,
                                    correct_label: gt_images,
                                    keep_prob:  KEEP_PROB,
                                    learning_rate: LEARNING_RATE}
                        )

How do I save the graph contained in train_op to saved_model.pb?

Oblomov

2 Answers

7

The easiest way is to use tf.train.write_graph. Usually, you just need to do something like:

tf.train.write_graph(my_graph, path_to_model_pb,
                     'saved_model.pb', as_text=False)

my_graph can be tf.get_default_graph() if you are using the default graph or any other tf.Graph (or tf.GraphDef) object.

Note that this saves only the graph definition, which is fine for visualization, but the values of any variables will not be saved unless you freeze the graph first (variable values live in the session object, not in the graph itself).
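For example, here is a minimal sketch of freezing a graph before writing it, using the TF1-style `tf.compat.v1` API. The toy graph, variable, and node names (`w`, `out`, `frozen_model.pb`) are made up for illustration, not taken from the question:

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # TF1-style API; assumed available
tf.disable_v2_behavior()

# Hypothetical tiny graph: one variable feeding an output node named 'out'.
graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(3.0, name='w')
    out = tf.identity(w * 2.0, name='out')
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    # Bake the current variable values into constants in the GraphDef.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ['out'])

# The frozen GraphDef now carries the values and can be written out.
logdir = tempfile.mkdtemp()
tf.train.write_graph(frozen, logdir, 'frozen_model.pb', as_text=False)
```

The resulting `frozen_model.pb` can then be loaded and visualized with the snippet from the question, and it will contain no variable ops.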

jdehesa
  • Thanks. But what do I have to do so that tf.get_default_graph() really returns the graph I just built? Please see the updated original question for some specification. – Oblomov Sep 01 '17 at 13:25
  • @user1934212 I've updated the answer. If your graph was created somewhere else (e.g. the previous function in the call stack), put the code inside a context like `with my_graph.as_default():` and then [`tf.get_default_graph()`](https://www.tensorflow.org/api_docs/python/tf/get_default_graph) will return the right graph. – jdehesa Sep 01 '17 at 13:53
  • If the graph was created somewhere else, would `train_op` from my original question be identical to `my_graph` in `with my_graph.as_default()`, and if not, how would I get from `train_op` to `my_graph`? – Oblomov Sep 02 '17 at 06:10
  • @user1934212 Well, I can't tell what is the value of `train_op` from your code, but assuming it is a [`tf.Operation`](https://www.tensorflow.org/api_docs/python/tf/Operation) or a [`tf.Tensor`](https://www.tensorflow.org/api_docs/python/tf/Tensor), you can then get the graph simply as `train_op.graph`. – jdehesa Sep 02 '17 at 10:53
  • What is `path_to_model_pb`? Why does this parameter exist along with the `'saved_model.pb'` one? – SomethingSomething Jan 07 '20 at 08:47
  • @SomethingSomething Look at [the docs](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/io/write_graph) of the function, it is just the directory. If you see [the source](https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/python/framework/graph_io.py), you'll see it's a very simple wrapper to another function and it just joins them both, not sure why they designed it like that... – jdehesa Jan 07 '20 at 11:54
2

I will cover this problem in steps:

To visualize variables like weights and biases, use tf.summary.histogram:

weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
tf.summary.histogram("weight1", weights['h1'])
tf.summary.histogram("weight2", weights['h2'])
tf.summary.histogram("weight3", weights['out'])
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
tf.summary.histogram("bias1", biases['b1'])
tf.summary.histogram("bias2", biases['b2'])
tf.summary.histogram("bias3", biases['out'])
cost = tf.sqrt(tf.reduce_mean(tf.squared_difference(pred, y)))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
tf.summary.scalar('rmse', cost)

Then, while training, include the following code:

summaries = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(init)
    writer = tf.summary.FileWriter("histogram_example", sess.graph)
    # Training cycle
    for epoch in range(training_epochs):
        for i in range(total_batch):
            # Get data into batch_x, batch_y here
            # Run optimization op (backprop) and cost op (to get loss value)
            summ, p, _, c = sess.run([summaries, pred, optimizer, cost],
                                     feed_dict={x: batch_x,
                                                y: batch_y})
            writer.add_summary(summ, global_step=epoch * total_batch + i)

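Putting the pieces above together, here is a self-contained sketch of the merge-and-write loop. The graph is a made-up toy (one scalar variable incremented each step), not the answer's network, and the log directory is a temporary one chosen for illustration:

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # TF1-style API; assumed available
tf.disable_v2_behavior()

# Hypothetical minimal graph: one scalar variable with a summary attached.
graph = tf.Graph()
with graph.as_default():
    v = tf.Variable(0.0, name='v')
    step_op = tf.assign_add(v, 1.0)
    tf.summary.scalar('v', v)
    summaries = tf.summary.merge_all()  # merges every summary defined above
    init = tf.global_variables_initializer()

logdir = tempfile.mkdtemp()
with tf.Session(graph=graph) as sess:
    sess.run(init)
    # Passing sess.graph also writes the graph itself for the Graphs tab.
    writer = tf.summary.FileWriter(logdir, sess.graph)
    for step in range(3):
        summ, _ = sess.run([summaries, step_op])
        writer.add_summary(summ, global_step=step)
    writer.close()
```

After running this, pointing TensorBoard at the log directory shows both the graph and the scalar curve.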
Tushar Gupta