
From what I've gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven't been able to find clear examples/information on how they work. What I already know is this:

  1. Save the model's variables into a checkpoint file (.ckpt) using a tf.train.Saver() and restore them later (source)
  2. Save a model into a .pb file and load it back in using tf.train.write_graph() and tf.import_graph_def() (source)
  3. Load in a model from a .pb file, retrain it, and dump it into a new .pb file using Bazel (source)
  4. Freeze the graph to save the graph and weights together (source)
  5. Use as_graph_def() to save the model, and for weights/variables, map them into constants (source)

However, I haven't been able to clear up several questions regarding these different methods:

  1. Regarding checkpoint files, do they only save the trained weights of a model? Could checkpoint files be loaded into a new program, and be used to run the model, or do they simply serve as ways to save the weights in a model at a certain time/stage?
  2. Regarding tf.train.write_graph(), are the weights/variables saved as well?
  3. Regarding Bazel, can it only save into/load from .pb files for retraining? Is there a simple Bazel command just to dump a graph into a .pb?
  4. Regarding freezing, can a frozen graph be loaded in using tf.import_graph_def()?
  5. The Android demo for TensorFlow loads in Google's Inception model from a .pb file. If I wanted to substitute my own .pb file, how would I go about doing that? Would I need to change any native code/methods?
  6. In general, what exactly is the difference between all these methods? Or more broadly, what is the difference between as_graph_def()/.ckpt/.pb?

In short, what I'm looking for is a way to save both a graph (that is, its operations and structure) and its weights/variables into a single file, which can then be loaded into another program and used for inference (not necessarily for continued training/retraining).

Documentation about this topic isn't very straightforward, so any answers/information would be greatly appreciated.

Technicolor
  • The newest/most complete API is the MetaGraph API, which gives you a way to save all three at once: 1) the graph, 2) the parameter values, and 3) collections: https://www.tensorflow.org/versions/r0.10/how_tos/meta_graph/index.html – Yaroslav Bulatov Aug 15 '16 at 04:56
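
For reference, here is a minimal sketch of the MetaGraph workflow described in that comment, assuming TensorFlow 1.x-style APIs; the variable, tensor, and file names (w, x, y, my_model.ckpt) are placeholders chosen for illustration:

import tensorflow as tf

# Program 1: build a toy graph, then save weights + MetaGraphDef together.
x = tf.placeholder(tf.float32, name="x")
w = tf.Variable(2.0, name="w")
y = tf.multiply(w, x, name="y")

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Writes the checkpoint files and my_model.ckpt.meta (the MetaGraphDef).
    saver.save(sess, "./my_model.ckpt")

# Program 2: restore the graph structure *and* the weights, then run it.
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph("./my_model.ckpt.meta")
    new_saver.restore(sess, "./my_model.ckpt")
    print(sess.run("y:0", feed_dict={"x:0": 3.0}))  # prints 6.0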

2 Answers


There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:

  1. The checkpoint files (produced e.g. by calling saver.save() on a tf.train.Saver object) contain only the values of the weights and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial for more details.

  2. tf.train.write_graph() only writes the graph structure, not the weights.

  3. Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)

  4. A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint (see the sketch after this list).

  5. The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputName and outputName strings that are passed to TensorFlowClassifier.initializeTensorFlow().

  6. The GraphDef is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDef libraries, Keras, and skflow, build on these mechanisms to provide more convenient ways to save and restore an entire model.
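
To tie points 2, 4, 5 and 6 together, here is a minimal sketch of that freeze-and-load workflow, assuming TensorFlow 1.x-style APIs; the node names ("input", "output") and the file name frozen.pb are placeholders, not names your own model has to use:

import tensorflow as tf

# Program 1: build (and, in practice, train) the model, then freeze it.
x = tf.placeholder(tf.float32, name="input")
w = tf.Variable(3.0, name="w")
y = tf.multiply(w, x, name="output")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Replace every variable with a constant holding its current value,
    # producing a single self-contained GraphDef, and write it to disk.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["output"])
    tf.train.write_graph(frozen, ".", "frozen.pb", as_text=False)

# Program 2: load the frozen graph and run it. Only the tensor names
# ("input:0", "output:0") need to be known here -- no checkpoint required.
tf.reset_default_graph()
with tf.gfile.GFile("frozen.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")

with tf.Session() as sess:
    print(sess.run("output:0", feed_dict={"input:0": 2.0}))  # prints 6.0

This is essentially what the freeze_graph script that ships with TensorFlow automates, starting from a saved GraphDef plus a checkpoint.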

mrry
  • Does this mean that the [C++ API documentation](https://www.tensorflow.org/versions/r0.11/api_docs/cc/index.html) lies when it says that you can load the graph saved with `tf.train.write_graph()` and then execute it? – mnicky Oct 26 '16 at 09:18
  • The C++ API documentation does not lie, but it is missing a few details. The most important detail is that, in addition to the `GraphDef` saved by `tf.train.write_graph()`, you also need to remember the names of the tensors that you want to feed and fetch when executing the graph (item 5 above). – mrry Oct 26 '16 at 14:43
  • @mrry: I tried to use the TensorFlow DeepDream example, but it seems it needs pretrained models in .pb format! I ran the CIFAR-10 example, but it only creates checkpoints! I couldn't find any .pb files whatsoever! How can I convert my checkpoints to the .pb format which the DeepDream example uses? – Hossein Jan 22 '17 at 12:48
  • @Coderx7 I really think you cannot convert a .ckpt to a .pb, since the checkpoint only contains the weights and variables and does not know anything about the structure of the graph. – David Ortiz Apr 24 '17 at 10:34
  • @mrry Is the .pb output file considered a frozen graph? I can't seem to extract the weights despite using `import_graph_def`: https://stackoverflow.com/questions/46696859/tf-graphkeys-trainable-variables-on-output-graph-pb-resulting-in-empty-list – Moondra Oct 12 '17 at 00:18
  • Is there simple code to load a .pb file and then run it? – Kong Jul 12 '18 at 02:54
  • The tutorial link is dead. – ComputerScientist Aug 17 '18 at 22:22
  • Updated all the tensorflow.org links to the current pages. – mrry Aug 20 '18 at 17:26

You can try the following code:

import tensorflow as tf

# Read the serialized GraphDef from the frozen .pb file.
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# import_graph_def() adds the nodes to the default graph and does not
# return a Graph object, so create the Session against the default graph.
tf.import_graph_def(graph_def, name="")
sess = tf.Session()
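
Once the GraphDef has been imported, the graph is run by looking up tensors by name and calling sess.run(). A hypothetical continuation of the code above (the tensor names image_tensor:0 and detection_scores:0 are only examples taken from typical object-detection frozen graphs; substitute the input/output names of your own model):

import numpy as np

# Hypothetical tensor names; use the actual names from your own .pb file.
image_tensor = sess.graph.get_tensor_by_name('image_tensor:0')
scores = sess.graph.get_tensor_by_name('detection_scores:0')
result = sess.run(scores,
                  feed_dict={image_tensor: np.zeros((1, 300, 300, 3), np.uint8)})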
Srihari Humbarwadi