I want to count the parameters in a TensorFlow model. My question is similar to this existing one:

How to count total number of trainable parameters in a tensorflow model?

However, if the model is defined by a graph loaded from a .pb file, none of the proposed answers work. I load the graph with the following function:

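import tensorflow as tf

def load_graph(model_file):
  # Parse the serialized GraphDef stored in the .pb file.
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  # Import its nodes into a fresh graph (names get the default "import/" prefix).
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph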

One example is loading a frozen_graph.pb file for retraining purposes, as in tensorflow-for-poets-2:

https://github.com/googlecodelabs/tensorflow-for-poets-2

Yanjun
  • I don't really see how the answers in the other question don't work. Once you have the graph, you just need to fetch the trainable variables of that specific graph. What have you tried after calling that function? Can you provide a sample .pbtxt file that reproduces the problem? – E_net4 May 03 '18 at 17:22

1 Answer

As I understand it, a GraphDef doesn't contain enough information to describe Variables. As explained here, you need a MetaGraph, which contains both a GraphDef and a CollectionDef, a map that can describe Variables. The following code should therefore give the correct count of trainable parameters.

Export a MetaGraph:

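import tensorflow as tf

# One trainable and one non-trainable variable.
a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    # Writes test.meta (the MetaGraph) plus the checkpoint files to the current directory.
    saver.save(sess, './test')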

Import the MetaGraph and count the total number of trainable parameters:

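import tensorflow as tf

# Rebuilds the graph, including its variable collections, from the MetaGraph.
saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    # Restoring the values is optional; counting only needs the variable shapes.
    saver.restore(sess, 'test')

total_parameters = 0
for variable in tf.trainable_variables():
    # A variable holds as many parameters as the product of its dimensions.
    total_parameters += variable.get_shape().num_elements()
print(total_parameters)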
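Counting parameters means summing, over all trainable variables, the product of each variable's dimensions; adding 1 per variable only counts the variables themselves. Here is a plain-Python sketch of that arithmetic. The shapes are hypothetical stand-ins for what `variable.get_shape().as_list()` returns (unknown dimensions show up there as `None`), and `count_parameters` is an illustrative helper, not a TensorFlow API.

```python
import math


def count_parameters(shapes):
    """Total number of elements over a list of fully defined shapes."""
    total = 0
    for shape in shapes:
        if any(dim is None or dim < 0 for dim in shape):
            raise ValueError(f"shape {shape} is not fully defined")
        # math.prod([]) == 1, so a scalar variable contributes one parameter.
        total += math.prod(shape)
    return total


# Hypothetical shapes: a 3x3 conv kernel (3 input, 64 output channels) and its bias.
shapes = [[3, 3, 3, 64], [64]]
print(len(shapes))               # counts variables: 2
print(count_parameters(shapes))  # counts parameters: 1792
```

`math.prod` needs Python 3.8 or later.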
Y. Luo
  • Does this mean that a .pb file doesn't contain any trainable variables? Thanks a lot. – Yanjun May 03 '18 at 22:38
  • @Yanjun I think the nodes are all there, but you can't tell which ones are trainable variables; at least, they are not in `tf.trainable_variables()` after being loaded. – Y. Luo May 03 '18 at 22:49
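Following up on the comments: a GraphDef still records each variable's shape in its node attributes, so you can count elements without a MetaGraph. What you cannot do that way is separate trainable from non-trainable variables. In a frozen graph such as frozen_graph.pb, the variables have already been converted to `Const` nodes, so the only option there is to count constants. That also picks up non-weight constants (for example shape arguments), so the result is an overestimate. Below is a minimal sketch under those assumptions. The function name, the `include_consts` flag, and the set of variable op types are hypothetical choices. It only reads the standard GraphDef proto fields (`node`, `op`, `attr`, `shape.dim`, `tensor.tensor_shape.dim`, `size`).

```python
import math

# Op types that hold variables in a non-frozen TF1 GraphDef; each stores its
# shape in the "shape" attr. This set is an assumption, not an exhaustive list.
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def _num_elements(dims):
    """Product of dimension sizes; an empty dim list (a scalar) gives 1."""
    sizes = [d.size for d in dims]
    if any(s < 0 for s in sizes):  # -1 marks an unknown dimension
        raise ValueError(f"shape {sizes} is not fully defined")
    return math.prod(sizes)


def count_graph_def_parameters(graph_def, include_consts=False):
    """Count elements of variable nodes and, optionally, of Const nodes.

    Trainable and non-trainable variables are NOT distinguished: that
    information lives in the MetaGraph's collections, not in the GraphDef.
    """
    total = 0
    for node in graph_def.node:
        if node.op in VARIABLE_OPS:
            total += _num_elements(node.attr["shape"].shape.dim)
        elif include_consts and node.op == "Const":
            # In a frozen graph the weights live here, alongside other constants.
            total += _num_elements(node.attr["value"].tensor.tensor_shape.dim)
    return total
```

For example, with `graph_def` parsed from the .pb file as in `load_graph`, `count_graph_def_parameters(graph_def, include_consts=True)` gives an upper bound on the parameter count of a frozen graph.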