
I have found two ways to save a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find documentation on using the model after loading it the second way.

Note: I want to use the SavedModelBuilder way because I train the model in Python and will use it at serving time in another language (Go), and it seems that SavedModelBuilder is the only option in that case.

This works great with tf.train.Saver() (first way):

model = tf.add(W * x, b, name="finalnode")

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: ACTUALLY USING THE MODEL AFTER LOADING IT.
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.

graph = tf.get_default_graph()  # the graph must be fetched before looking up tensors by name
model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})

tf.saved_model.builder.SavedModelBuilder() is defined in the Readme, but after loading the model with tf.saved_model.loader.load(sess, [], export_dir), I can't find documentation on getting back the nodes (see "finalnode" in the code above).

  • **Note**: This function will only be available through the v1 compatibility library as `tf.compat.v1.saved_model.builder.SavedModelBuilder` or `tf.compat.v1.saved_model.Builder`. TensorFlow 2.0 will introduce a new object-based method of creating SavedModels. – Krishna Oct 11 '19 at 12:05

4 Answers


What was missing was the signature:

# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map={
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={"x": x},
            outputs={"finalnode": model})
        })
builder.save()

# Loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))
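
If you prefer not to hard-code the tensor names at load time, you can instead look them up in the signature saved under the "model" key. A sketch of the same loading step, using the same export_dir and "tag" as above:

# Loading via the saved signature instead of hard-coded tensor names
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(sess, ["tag"], export_dir)
    signature = meta_graph.signature_def["model"]
    x = sess.graph.get_tensor_by_name(signature.inputs["x"].name)
    model = sess.graph.get_tensor_by_name(signature.outputs["finalnode"].name)
    print(sess.run(model, {x: [5, 6, 7, 8]}))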

Here's a code snippet to save a model and then load it back for prediction, using simple_save:

# Save the model:
tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                                   inputs={"inputImageBatch": X_train, "inputClassBatch": Y_train,
                                           "isTrainingBool": isTraining},
                                   outputs={"predictedClassBatch": predClass})

Note that using simple_save sets certain default values (see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py).
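
For reference, here is roughly what simple_save does under the hood, based on the linked source (a hedged sketch; details such as the assets collection and main op are omitted): it wraps the inputs/outputs dicts in the default serving signature and saves under the SERVING tag.

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# Sketch of the rough equivalent of the simple_save call above
builder = tf.saved_model.builder.SavedModelBuilder(saveModelPath)
builder.add_meta_graph_and_variables(
    sess,
    tags=[tag_constants.SERVING],
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.saved_model.signature_def_utils.predict_signature_def(
                inputs={"inputImageBatch": X_train, "inputClassBatch": Y_train,
                        "isTrainingBool": isTraining},
                outputs={"predictedClassBatch": predClass})
    })
builder.save()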

Now, to restore and use the inputs/outputs dict:

from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

with tf.Session() as sess:
  model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING])  # Note: the SERVING tag is what simple_save uses by default.
  signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  graph = tf.get_default_graph()

  # Look up each tensor by the name recorded in the signature
  inputImage = graph.get_tensor_by_name(signature.inputs['inputImageBatch'].name)
  inputLabel = graph.get_tensor_by_name(signature.inputs['inputClassBatch'].name)
  isTraining = graph.get_tensor_by_name(signature.inputs['isTrainingBool'].name)
  outputPrediction = graph.get_tensor_by_name(signature.outputs['predictedClassBatch'].name)

  outPred = sess.run(outputPrediction, feed_dict={inputImage: sampleImages, isTraining: False})

  print("predicted classes:", outPred)

Note: the default signature_def key is needed to look up the tensor names stored under the keys of the input and output dicts.


TensorFlow's preferred way of building and using a model across different languages is TensorFlow Serving.

In your case, you are using saver.save to save the model. This produces a meta file, a checkpoint file, and a few other files that store the weights, the network structure, the number of steps trained, and so on. This is the preferred way of saving while you are still training.

Once training is done, you should export the model with SavedModelBuilder from the files saved by saver.save. The exported model includes a pb file that contains the whole network and its weights.
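
For example, a minimal sketch of that export step (assuming the graph is already built in the session, the checkpoint from the question lives at /tmp/model, and /tmp/export is a hypothetical output directory):

import tensorflow as tf

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the training checkpoint written by saver.save
    saver.restore(sess, "/tmp/model")
    # Export a SavedModel (saved_model.pb + variables/) for serving
    builder = tf.saved_model.builder.SavedModelBuilder("/tmp/export")
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING])
    builder.save()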

This exported model can then be served by TensorFlow Serving, and other languages can call it over the gRPC protocol.
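
For instance, a minimal sketch of a Python gRPC client for TensorFlow Serving (the host/port, model name "my_model", signature key "model", and tensor keys are assumptions for illustration; it requires the tensorflow-serving-api package):

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# 8500 is the default gRPC port of tensorflow_model_server
channel = grpc.insecure_channel("localhost:8500")
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = "my_model"          # hypothetical --model_name used at serving time
request.model_spec.signature_name = "model"   # signature key chosen when exporting
request.inputs["x"].CopyFrom(
    tf.make_tensor_proto([5.0, 6.0, 7.0], dtype=tf.float32))

result = stub.Predict(request, 10.0)  # 10 second timeout
print(result.outputs["finalnode"])

A client in another language (e.g. Go) would use the same protobuf/gRPC service definitions.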

The whole procedure is described in this excellent tutorial.

Sumsuddin Shojib
  • 3,583
  • 3
  • 26
  • 45
  • thanks for the answer and the link, but that doesn't really answer my question... – Thomas Aug 17 '17 at 01:20
  • The link *does* have the answer somewhere after "The last step — save the model", but it is easy to find only if you already know where to look... it could definitely be more concise, but thanks anyway for the link and the insights – fr_andres Nov 28 '17 at 14:50

A code snippet that worked for me to load a pb file and run inference on a single image.

The code follows these steps: load the pb file into a GraphDef (a serialized version of a graph, used to read pb files), import the GraphDef into a Graph, get the input and output tensors by name, and run inference on a single image.

import tensorflow as tf 
import numpy as np
import cv2

INPUT_TENSOR_NAME = 'input_tensor_name:0'
OUTPUT_TENSOR_NAME = 'output_tensor_name:0'

# Read image, get shape
# Add dimension to fit batch shape
img = cv2.imread(IMAGE_PATH)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = img.astype(float)
height, width, channels = image.shape
image = np.expand_dims(image, 0)  # Add dimension (to fit batch shape)


# Read the pb file into a GraphDef (a serialized version of a graph, used to read pb files)
with tf.gfile.FastGFile(PB_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Load GraphDef into Graph
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# Get tensors (input and output) by name
input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# Inference on single image
with tf.Session(graph=graph) as sess:
    output_vals = sess.run(output_tensor, feed_dict={input_tensor: image})