
I have been investigating how to import the checkpoint of a pretrained model in TensorFlow, so that I can examine its structure and use it for image classification.

Specifically, I want the MobileNet model found here. I couldn't find any reasonable way to import the model from the various *.ckpt.* files, but after some forum sniffing I found a gist by GitHub user StanislawAntol that purports to convert those files into a frozen model, i.e. a ProtoBuf (.pb) file. The gist is here.

Running the script gives me a bunch of .pb files, which I hoped I could work with. Indeed, this SO question seemed to answer my prayers.

I have been trying variants of the following code, but to no avail. Every object returned by tf.import_graph_def seemed to be None.

import tensorflow as tf
from tensorflow.python.platform import gfile

model_filename = LOCATION_OF_PB_FILE

with gfile.FastGFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name='')

print(g_in)

Is there something I'm missing here? Is the entire conversion to .pb erroneous?

Prunus Persica

1 Answer


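tf.import_graph_def does not return the graph; it imports the nodes of the GraphDef into the current "default graph". When return_elements is not given (as in your code), it returns None; otherwise it returns the requested Operation and/or Tensor objects. See the documentation for tf.import_graph_def for details on the return value.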
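The sketch below shows both behaviours side by side. It assumes TensorFlow 1.13 or later (including 2.x), where the graph-mode API is reachable as tf.compat.v1; the tiny two-node graph and the names "input", "output" and "imported" are hypothetical stand-ins for the MobileNet .pb.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TF >= 1.13 (also works in 2.x)

# Hypothetical stand-in for the MobileNet GraphDef: input -> (x * 2) -> output
src = tf1.Graph()
with src.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="input")
    tf.identity(x * 2.0, name="output")
graph_def = src.as_graph_def()

dst = tf1.Graph()
with dst.as_default():
    # Inside this block `dst` is the default graph, so the nodes land in it.
    # No return_elements: the call returns None.
    result = tf1.import_graph_def(graph_def, name="")
    # With return_elements: the requested tensors are returned, and the
    # imported names get the "imported/" prefix.
    inp, out = tf1.import_graph_def(
        graph_def, name="imported", return_elements=["input:0", "output:0"]
    )

print(result)              # None
print(inp.name, out.name)  # imported/input:0 imported/output:0
```

So the None you are seeing is expected; the imported operations are in the graph that was the default at the time of the call.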

In your case, you can inspect the graph using tf.get_default_graph(). For example:

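import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile(model_filename, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
# Returns None here; the nodes are added to the default graph
tf.import_graph_def(graph_def, name='')

# Fetch the default graph, which now holds the imported operations
g = tf.get_default_graph()
print(len(g.get_operations()))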
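Once the graph is imported, tensors can be looked up by name and evaluated in a session, which is what you need for classification. A minimal sketch, again assuming tf.compat.v1; the softmax graph and the tensor names "input:0" and "probs:0" are hypothetical, and a real MobileNet .pb will use its own input/output names (listing g.get_operations() is one way to find them).

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TF >= 1.13 (also works in 2.x)

# Hypothetical stand-in for a classifier: input logits -> softmax -> probs
src = tf1.Graph()
with src.as_default():
    logits = tf1.placeholder(tf.float32, shape=[None, 3], name="input")
    tf.identity(tf.nn.softmax(logits), name="probs")
graph_def = src.as_graph_def()

graph = tf1.Graph()
with graph.as_default():
    tf1.import_graph_def(graph_def, name="")

# Tensors are addressed as "<op name>:<output index>"
input_t = graph.get_tensor_by_name("input:0")
probs_t = graph.get_tensor_by_name("probs:0")

with tf1.Session(graph=graph) as sess:
    batch = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)  # one fake input row
    probs = sess.run(probs_t, feed_dict={input_t: batch})

print(probs.argmax(axis=1))  # prints [2]: index of the top class per row
```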

See documentation for tf.Graph for more details on the notion of a "default graph" and scoping.
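If you would rather not rely on the process-wide default graph, you can create a tf.Graph yourself and make it the default only while importing. The sketch below assumes tf.compat.v1 and tf.io.gfile.GFile (TF 1.13+ or 2.x); load_graph is a hypothetical helper name, and the tiny graph written to a temporary tiny_graph.pb stands in for the frozen MobileNet file produced by the gist.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TF >= 1.13 (also works in 2.x)


def load_graph(pb_path):
    """Hypothetical helper: read a serialized GraphDef and import it into a new tf.Graph."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf1.Graph()
    with graph.as_default():
        # `graph` is the default graph only inside this block, so the nodes
        # go into it rather than into the global default graph.
        tf1.import_graph_def(graph_def, name="")
    return graph


# Write a tiny stand-in .pb so the example runs without the MobileNet file.
src = tf1.Graph()
with src.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="input")
    tf.identity(x + 1.0, name="output")
pb_path = os.path.join(tempfile.mkdtemp(), "tiny_graph.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(src.as_graph_def().SerializeToString())

graph = load_graph(pb_path)
op_names = [op.name for op in graph.get_operations()]
print(op_names)  # structure of the imported graph
```

Returning the graph object explicitly makes it easy to keep several models apart and to pass the right graph to tf.compat.v1.Session(graph=...).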

Hope that helps.

ash