
I am looking at Google's example of how to deploy and use a pre-trained TensorFlow graph (model) on Android. This example uses a .pb file at:

https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

which is a link to a file that downloads automatically.

The example shows how to load the .pb file into a TensorFlow session and use it to perform classification, but it doesn't seem to mention how to generate such a .pb file after a graph has been trained (e.g., in Python).

Are there any examples on how to do that?

nbro
Shanqing Cai

6 Answers


EDIT: The freeze_graph.py script, which is part of the TensorFlow repository, now serves as a tool that generates a protocol buffer representing a "frozen" trained model from an existing TensorFlow GraphDef and a saved checkpoint. It uses the same steps as described below, but it is much easier to use.


Currently the process isn't well documented (and is subject to refinement), but the approximate steps are as follows:

  1. Build and train your model as a tf.Graph called g_1.
  2. Fetch the final values of each of the variables and store them as numpy arrays (using Session.run()).
  3. In a new tf.Graph called g_2, create tf.constant() tensors for each of the variables, using the value of the corresponding numpy array fetched in step 2.
  4. Use tf.import_graph_def() to copy nodes from g_1 into g_2, and use the input_map argument to replace each variable in g_1 with the corresponding tf.constant() tensors created in step 3. You may also want to use input_map to specify a new input tensor (e.g. replacing an input pipeline with a tf.placeholder()). Use the return_elements argument to specify the name of the predicted output tensor.

  5. Call g_2.as_graph_def() to get a protocol buffer representation of the graph.
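The steps above can be sketched without TensorFlow on a toy graph. This is a hypothetical model, not the TensorFlow API: the graph is a plain dict mapping node names to an op type and input names (loosely mimicking a GraphDef), `trained_values` stands in for the numpy arrays fetched in step 2, and `freeze_variables` plays the role of steps 3 and 4 by building a new graph in which each variable is replaced by a constant.

```python
import copy

# Hypothetical toy "GraphDef": node name -> {"op": ..., "inputs": [...]}.
# Op names loosely follow TensorFlow's ("Placeholder", "VariableV2", ...).
g_1 = {
    "input": {"op": "Placeholder", "inputs": []},
    "weights": {"op": "VariableV2", "inputs": []},
    "bias": {"op": "VariableV2", "inputs": []},
    "matmul": {"op": "MatMul", "inputs": ["input", "weights"]},
    "output": {"op": "Add", "inputs": ["matmul", "bias"]},
}

# Step 2 stand-in: values fetched with Session.run() (made-up numbers).
trained_values = {"weights": [[0.5], [-1.0]], "bias": [0.1]}


def freeze_variables(graph, values):
    """Return a new graph (g_2) in which every variable node becomes a
    constant node carrying its trained value (steps 3 and 4)."""
    frozen = {}
    for name, node in graph.items():
        if node["op"] == "VariableV2":
            # Step 3: a constant holding the fetched value.
            frozen[name] = {"op": "Const", "inputs": [],
                            "value": copy.deepcopy(values[name])}
        else:
            # Step 4: copy every other node unchanged. Because the constant
            # keeps the variable's name, its consumers now read the constant
            # (in TensorFlow, input_map performs this rewiring).
            frozen[name] = copy.deepcopy(node)
    return frozen


g_2 = freeze_variables(g_1, trained_values)
print(sorted(n for n, d in g_2.items() if d["op"] == "Const"))
# → ['bias', 'weights']
```

The original `g_1` is left untouched, mirroring how the real procedure builds a separate `g_2` instead of mutating the trained graph.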

(NOTE: The generated graph will still contain extra nodes that are only needed for training. Although it is not part of the public API, you may wish to use the internal graph_util.extract_sub_graph() function to strip these nodes from the graph.)
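To illustrate what stripping the training nodes means, here is a toy, TensorFlow-free sketch of the idea behind extract_sub_graph(): keep only the nodes that the requested outputs transitively depend on. The dict-based graph, the function name `toy_extract_sub_graph`, and node names such as "gradients" and "train_step" are hypothetical; the real function operates on a GraphDef protobuf.

```python
from collections import deque

# Hypothetical toy graph: node name -> list of input node names.
graph = {
    "input": [],
    "weights": [],
    "output": ["input", "weights"],
    "labels": [],
    "loss": ["output", "labels"],
    "gradients": ["loss"],
    "train_step": ["gradients", "weights"],
}


def toy_extract_sub_graph(graph, output_names):
    """Keep only nodes reachable backwards from output_names (breadth-first)."""
    for name in output_names:
        if name not in graph:
            # Same message as the "... is not in graph" assertion quoted in the comments.
            raise KeyError(f"{name} is not in graph")
    keep = set()
    queue = deque(output_names)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(graph[name])
    # Preserve the original node order (a GraphDef stores nodes as a list).
    return {name: inputs for name, inputs in graph.items() if name in keep}


print(list(toy_extract_sub_graph(graph, ["output"])))
# → ['input', 'weights', 'output']
```

Everything downstream of the prediction ("loss", "gradients", "train_step") and the unused "labels" input is dropped, which is why asking for an output name that does not exist fails immediately.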

mrry
    tensorflow.python.framework.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None) does a very similar thing in a single function call. Also, it calls extract_sub_graph(input_graph_def, output_node_names) internally – MZHm Jan 29 '17 at 20:11

As an alternative to my previous answer using freeze_graph(), which is only convenient when called as a script, there is a function that does all the heavy lifting for you and is suitable for calling from your normal model training code.

convert_variables_to_constants() does two things:

  • It freezes the weights by replacing variables with constants
  • It removes nodes which are not related to feedforward prediction

Assuming sess is your tf.Session() and "output" is the name of your prediction node, the following code will serialize your minimal graph to both text and binary protobuf.


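import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# Replace variables with constants and drop nodes not needed to compute "output".
minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])

# Write the same GraphDef once as binary protobuf and once as human-readable text.
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False)
tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)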
mirosval
    Just a little note: previously, `graph_util` was in `tensorflow.python.client`, so it may be useful to use try-except to load the right package and support multiple versions. – AkiRoss Sep 21 '16 at 08:28
  • I tried using this with Keras models and got an error: File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/graph_util.py", line 158, in extract_sub_graph assert d in name_to_node_map, "%s is not in graph" % d AssertionError: output_node is not in graph. The relevant lines (shortened) are: model = Sequential() model.add(Dense(12, ...) model.add(Dense(8, .... ) model.add(Dense(1, init='uniform', activation='sigmoid', name='output_node')) ....... minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ['output_node']) Any idea why? – Nir Jan 06 '17 at 20:31
  • Likely your keras output node is named differently. – mirosval Jan 08 '17 at 14:15
  • If I have a pregenerated graph (inceptionv3 built by the slim library), how can I find out the names of the input and output tensors? It doesn't look like they're built in order, or perhaps it is up to me to add a tensor at the beginning to feed a JPEG into the graph... – user2647513 Aug 11 '17 at 23:22
  • I think convert_variables_to_constants is a newer API that was not available at the time of the question, which is great. To complete the answer, can you please provide the second part, i.e. how to load that in C++? Or do we not need to, and can we just use the model directly? – Mostafa Hassan Nov 03 '17 at 05:19
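AkiRoss's suggestion of falling back between module locations can be wrapped in a small helper. This is a sketch: `import_first` is a hypothetical name, and the two TensorFlow module paths in the usage comment are the ones mentioned in this thread (`tensorflow.python.framework.graph_util` and the older `tensorflow.python.client`); check them against your installed version.

```python
import importlib


def import_first(candidates):
    """Return the attribute from the first (module, attribute) pair that imports.

    candidates: iterable of (module_path, attribute_name) tuples, tried in order.
    Raises ImportError if none of them can be imported.
    """
    errors = []
    for module_path, attr in candidates:
        try:
            module = importlib.import_module(module_path)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            # Remember why this candidate failed and try the next one.
            errors.append(f"{module_path}.{attr}: {exc}")
    raise ImportError("no candidate could be imported:\n" + "\n".join(errors))


# Usage (assumes TensorFlow 1.x is installed; module paths taken from the comments):
# convert_variables_to_constants = import_first([
#     ("tensorflow.python.framework.graph_util", "convert_variables_to_constants"),
#     ("tensorflow.python.client.graph_util", "convert_variables_to_constants"),
# ])
```

Collecting the individual errors keeps the final ImportError informative when no version matches.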

I could not figure out how to implement the method described by mrry, but here is how I solved it. I'm not sure it is the best way to solve the problem, but at least it works.

As write_graph can also store the values of constants, I added the following code to my Python script just before writing the graph with the write_graph function:

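import tensorflow as tf
# Requires a default session, e.g. inside `with sess.as_default():`
for v in tf.trainable_variables():
    vc = tf.constant(v.eval())  # bake the trained value into the graph
    tf.assign(v, vc, name="assign_variables")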

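This creates constants that hold the variables' trained values, and then creates assign ops named "assign_variables" that assign those constants to the variables (TensorFlow makes repeated names unique, so the ops become "assign_variables", "assign_variables_1", "assign_variables_2", and so on). Now, when you call write_graph, it will store the variables' values in the file in the form of constants.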

The only remaining step is to run these "assign_variables" ops from the C++ code, so that your variables are assigned the constant values stored in the file. Here is one way to do it:

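      // Assumes graph_def (a tensorflow::GraphDef) was already loaded from the .pb file.
      tensorflow::Session* session = nullptr;
      Status status = NewSession(SessionOptions(), &session);
      if (status.ok()) status = session->Create(graph_def);
      std::vector<tensorflow::Tensor> outputs;
      // The assign ops are named "assign_variables", "assign_variables_1", ...
      // Run them in order; Run() returns an error once a name does not exist.
      for (int i = 0; status.ok(); i++) {
        std::string name = (i == 0) ? std::string("assign_variables")
                                    : "assign_variables_" + std::to_string(i);
        status = session->Run({}, {name}, {}, &outputs);
      }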
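The C++ loop relies on how TensorFlow makes duplicate op names unique: the first op gets "assign_variables", later ones get "assign_variables_1", "assign_variables_2", and so on. Below is a TensorFlow-free sketch of both halves; `unique_name` is a simplified, hypothetical stand-in for TensorFlow's internal naming, and `run_assigns` mimics the C++ loop by stopping at the first name that is not in the graph.

```python
def unique_name(name, used):
    """Simplified TensorFlow-style uniquification:
    'x', then 'x_1', 'x_2', ... for repeated requests of the same name."""
    if name not in used:
        used.add(name)
        return name
    i = 1
    while f"{name}_{i}" in used:
        i += 1
    new_name = f"{name}_{i}"
    used.add(new_name)
    return new_name


def run_assigns(graph_nodes, base="assign_variables"):
    """Mimic the C++ loop: try base, base_1, base_2, ... and stop at the
    first name that does not exist (where Session::Run would fail)."""
    ran = []
    i = 0
    while True:
        name = base if i == 0 else f"{base}_{i}"
        if name not in graph_nodes:
            break
        ran.append(name)  # the real code calls session->Run(...) here
        i += 1
    return ran


used = set()
# One assign op per trainable variable (three hypothetical variables).
names = [unique_name("assign_variables", used) for _ in range(3)]
print(names)
# → ['assign_variables', 'assign_variables_1', 'assign_variables_2']
print(run_assigns(set(names)))
# → ['assign_variables', 'assign_variables_1', 'assign_variables_2']
```

A consequence of this design is that a gap in the sequence (for example, if "assign_variables_1" were missing) silently stops the loop early, so every later assign would be skipped.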
Mostafa Hassan
  • Do you have a full example of this? When I try and load the saved graph in python (I wanted to test it before fighting with the C++ api) I get a decoding error. `google.protobuf.message.DecodeError: Tag had invalid wire type` – James Sweet Jan 26 '16 at 17:16
  • That depends on how you read the saved graph. You can read it from Python with `with gfile.FastGFile("filename.pb", 'rb') as f:` `graph_def = tf.GraphDef()` `graph_def.ParseFromString(f.read())` `tf.import_graph_def(graph_def, name='')`, and then initialize the variables the same way by calling `session.run("assign_variables")`. – Mostafa Hassan Jan 26 '16 at 21:01

Here's another take on @Mostafa's answer. A somewhat cleaner way to run the tf.assign ops is to store them in a tf.group. Here's my Python code:

  import tensorflow as tf

  ops = []
  for v in tf.trainable_variables():
    vc = tf.constant(v.eval())  # needs a default session
    ops.append(tf.assign(v, vc))
  # A single named op that runs all the assigns.
  tf.group(*ops, name="assign_trained_variables")

And in C++:

  std::vector<tensorflow::Tensor> tmp;
  status = session.Run({}, {}, { "assign_trained_variables" }, &tmp);
  if (!status.ok()) {
    // Handle error
  }

This way you have only one named op to run on the C++ side, so you don't have to mess around with iterating over nodes.
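Roughly speaking, tf.group() creates a single no-op node whose control inputs are all the grouped ops, so running that one node forces every assign to run. Here is a toy, TensorFlow-free sketch of that behaviour; the `ToyGraph` class and its methods are hypothetical and only model control dependencies, and the variable values are made up.

```python
class ToyGraph:
    """Hypothetical mini-graph: each node has an action and control inputs."""

    def __init__(self):
        self.nodes = {}  # name -> (action or None, [control input names])

    def add(self, name, action=None, control_inputs=()):
        self.nodes[name] = (action, list(control_inputs))
        return name

    def run(self, name, _done=None):
        """Run a node after its control inputs, each node at most once."""
        done = set() if _done is None else _done
        if name in done:
            return
        action, deps = self.nodes[name]
        for dep in deps:
            self.run(dep, done)
        if action is not None:
            action()
        done.add(name)


variables = {"w": 0.0, "b": 0.0}
trained = {"w": 0.5, "b": 0.1}  # made-up values baked in as "constants"

g = ToyGraph()
assigns = [
    g.add(f"assign_{k}", action=lambda k=k: variables.__setitem__(k, trained[k]))
    for k in variables
]
# Like tf.group(*ops, name="assign_trained_variables"): a no-op with control inputs.
g.add("assign_trained_variables", control_inputs=assigns)

g.run("assign_trained_variables")  # one call, as on the C++ side
print(variables)
# → {'w': 0.5, 'b': 0.1}
```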

Kris Giesing

Just found this post and it was very useful, thanks! I'm also going with @Mostafa's method, though my C++ code is a bit different:

    // graph is the tensorflow::GraphDef loaded from the .pb file.
    std::vector<std::string> names;
    std::vector<tensorflow::Tensor> outputs;
    int node_count = graph.node_size();
    std::cout << node_count << " nodes in graph" << std::endl;

    // iterate all nodes
    for (int i = 0; i < node_count; i++) {
        const auto& n = graph.node(i);
        std::cout << i << ":" << n.name() << std::endl;

        // if name contains "var_hack", add to vector
        if (n.name().find("var_hack") != std::string::npos) {
            names.push_back(n.name());
            std::cout << "......bang" << std::endl;
        }
    }
    session.Run({}, names, {}, &outputs);

NB: I use "var_hack" as my variable name in Python.
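The node-filtering loop above translates directly to Python, which can be handy for checking which names the C++ side will pick up before deploying. This sketch works on a plain list of node names; with TensorFlow you would take them from `[n.name for n in graph_def.node]`. `select_nodes` and the example names are hypothetical.

```python
def select_nodes(node_names, marker="var_hack"):
    """Return, in graph order, the names containing marker, like the C++
    check n.name().find("var_hack") != std::string::npos."""
    return [name for name in node_names if marker in name]


# Hypothetical node names, e.g. from [n.name for n in graph_def.node].
node_names = [
    "input",
    "var_hack_w",
    "var_hack_w/read",
    "Assign_var_hack_w",
    "output",
]
for i, name in enumerate(node_names):
    print(f"{i}:{name}")
selected = select_nodes(node_names)
print(selected)
# → ['var_hack_w', 'var_hack_w/read', 'Assign_var_hack_w']
```

Note that a plain substring test also matches derived nodes such as `var_hack_w/read` (in TensorFlow 1.x a variable's read op is typically named `<variable>/read`), so a stricter prefix or exact-name check may be safer.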

memo

I found a freeze_graph() function in the TensorFlow codebase that might be helpful here. From what I understand, it swaps variables for constants before serializing the GraphDef, so when you load this graph from C++ it no longer has any variables that need to be set, and you can use it directly for predictions.

There is also a test for it and some description in the Guide.

This seems like the cleanest option here.

mirosval