TLDR: The code below is what you might want to use:
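    import tensorflow as tf  # written against the TF 1.x API

    for n in tf.get_default_graph().as_graph_def().node:
        # 'strides' is stored as a list of int64 values
        if 'strides' in n.attr.keys():
            print(n.name, [int(a) for a in n.attr['strides'].list.i])
        # 'shape' is stored as a TensorShapeProto with one entry per dimension
        if 'shape' in n.attr.keys():
            print(n.name, [int(a.size) for a in n.attr['shape'].shape.dim])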
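If you want to try the same loop without building a model first, here is a minimal sketch that runs it on a hand-built GraphDef. The helper name collect_strides_and_shapes and the two nodes (images, conv1) are made up for illustration. It assumes TensorFlow is installed and only uses the generated protobuf classes under tensorflow.core.framework.

```python
from tensorflow.core.framework import graph_pb2


def collect_strides_and_shapes(graph_def):
    """Return {(node_name, attr_name): [ints]} for 'strides' and 'shape' attrs."""
    found = {}
    for n in graph_def.node:
        if "strides" in n.attr:
            found[(n.name, "strides")] = [int(a) for a in n.attr["strides"].list.i]
        if "shape" in n.attr:
            found[(n.name, "shape")] = [int(d.size) for d in n.attr["shape"].shape.dim]
    return found


# Hypothetical two-node graph, built by hand instead of taken from a real model.
graph_def = graph_pb2.GraphDef()

placeholder = graph_def.node.add(name="images", op="Placeholder")
for size in (-1, 28, 28, 3):
    # Each dimension is its own Dim message with a `size` field.
    placeholder.attr["shape"].shape.dim.add(size=size)

conv = graph_def.node.add(name="conv1", op="Conv2D")
conv.attr["strides"].list.i.extend([1, 2, 2, 1])

result = collect_strides_and_shapes(graph_def)
print(result)
```

With a real TF 1.x model you would pass tf.get_default_graph().as_graph_def() in place of the hand-built graph_def.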
The trick to doing this is to understand what protobufs are. Let's go through the tutorial mentioned above.
First of all, the tutorial contains this statement:

    for node in graph_def.node
> Each node is a NodeDef object, defined in
> tensorflow/core/framework/node_def.proto. These are the fundamental
> building blocks of TensorFlow graphs, with each one defining a single
> operation along with its input connections. Here are the members of a
> NodeDef, and what they mean.
Note the following in node_def.proto:
- It imports attr_value.proto.
- There are attributes such as name, op, input, device, attr. Specifically, there's a repeated term in front of input. We can ignore this for now.
In Python, a protobuf message exposes its fields as attributes, much like an instance of a Python class, so we can read node.name, node.op, node.input, node.device, node.attr, etc.
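As a rough illustration of that attribute access, the sketch below builds a NodeDef by hand and reads its fields back. The node name, inputs and device string are invented values, not taken from any real graph, and it assumes TensorFlow is installed.

```python
from tensorflow.core.framework import node_def_pb2

# Hand-built node for illustration; all values here are made up.
node = node_def_pb2.NodeDef(
    name="conv1",
    op="Conv2D",
    input=["images", "conv1/weights"],
    device="/device:CPU:0",
)

# Scalar fields read like plain attributes.
print(node.name, node.op, node.device)

# `input` is declared `repeated string`, so it behaves like a sequence of strings.
inputs = list(node.input)
print(inputs)

# `attr` is a map from attribute name to AttrValue; nothing has been set yet.
attr_count = len(node.attr)
print(attr_count)
```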
What we would like to access now would be the contents in node.attr. If we refer to the tutorial once again, it specifies:
> This is a key/value store holding all the attributes of a node. These
> are the permanent properties of nodes, things that don't change at
> runtime such as the size of filters for convolutions, or the values of
> constant ops. Because there can be so many different types of
> attribute values, from strings, to ints, to arrays of tensor values,
> there's a separate protobuf file defining the data structure that
> holds them, in tensorflow/core/framework/attr_value.proto.
> Each attribute has a unique name string, and the expected attributes
> are listed when the operation is defined. If an attribute isn't
> present in a node, but it has a default listed in the operation
> definition, that default is used when the graph is created.
> You can access all of these members by calling node.name, node.op,
> etc. in Python. The list of nodes stored in the GraphDef is a full
> definition of the model architecture.
Since this is a key-value store, we can call n.attr.keys() to see the list of keys this node has. We can go further and call n.attr['strides'] to access the strides, if such a key is available. When we try to print this, we get the following:
    list {
      i: 1
      i: 2
      i: 2
      i: 1
    }
And this is where it starts to get confusing, because we might try list(n.attr['strides']) or something of this sort, which fails because the value is a protobuf message and not a Python sequence. If we look at attr_value.proto, we can understand what's going on. An AttrValue holds a oneof value, and in this case the member that is set is ListValue list, so we can call n.attr['strides'].list. If we print this, we get the following:
    i: 1
    i: 1
    i: 1
    i: 1
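Rather than reading attr_value.proto every time, you can also ask the message itself which member of the oneof is populated. The sketch below is one way to do that, using the standard protobuf methods WhichOneof and ListFields on a hand-built AttrValue that stands in for n.attr['strides']. The variable names are hypothetical, and it assumes TensorFlow is installed.

```python
from tensorflow.core.framework import attr_value_pb2

# Hand-built attribute value standing in for n.attr['strides'].
strides_attr = attr_value_pb2.AttrValue(
    list=attr_value_pb2.AttrValue.ListValue(i=[1, 2, 2, 1])
)

# WhichOneof reports which member of the `value` oneof is set.
kind = strides_attr.WhichOneof("value")
print(kind)

# ListFields returns (descriptor, value) pairs for the fields that hold data,
# which tells us which repeated field inside the ListValue to read.
populated = [field.name for field, _ in strides_attr.list.ListFields()]
print(populated)
```

Here kind names the oneof member to follow (list), and populated names the field inside it that carries the values (i).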
We might next try to do this: [a for a in n.attr['strides'].list] or [a.i for a in n.attr['strides'].list]. However, neither works, because a ListValue message is not iterable. This is where repeated is an important term to understand. In attr_value.proto, ListValue declares a repeated int64 field named i, which means the integers live in a list-like field and you have to reach them through the i attribute. Doing [int(a) for a in n.attr['strides'].list.i] then gives us what we want, a Python list that we can use:
[1, 1, 1, 1]
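To make the difference concrete, here is a small sketch that shows the failing iteration next to the working one. The node is built by hand and int_list_attr is a hypothetical helper name, not part of TensorFlow. It assumes TensorFlow is installed.

```python
from tensorflow.core.framework import node_def_pb2


def int_list_attr(node, key):
    """Return the int64 list stored under `key`, or None if the attribute is absent."""
    # Check membership first: indexing a missing key on a message-valued
    # protobuf map would create an empty entry as a side effect.
    if key not in node.attr:
        return None
    return [int(v) for v in node.attr[key].list.i]


# Hand-built node for illustration.
node = node_def_pb2.NodeDef(name="conv1", op="Conv2D")
node.attr["strides"].list.i.extend([1, 1, 1, 1])

# Iterating over the ListValue message itself raises TypeError.
try:
    [a for a in node.attr["strides"].list]
    message_is_iterable = True
except TypeError:
    message_is_iterable = False

strides = int_list_attr(node, "strides")
missing = int_list_attr(node, "dilations")
print(message_is_iterable, strides, missing)
```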