
I see from the tutorial that we can do this:

import tensorflow as tf
for node in tf.get_default_graph().as_graph_def().node: print(node)

Running this on an arbitrary network prints many key-value pairs for each node. For example:

name: "conv2d_2/convolution"
op: "Conv2D"
input: "max_pooling2d/MaxPool"
input: "conv2d_1/kernel/read"
device: "/device:GPU:0"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "data_format"
  value {
    s: "NHWC"
  }
}
attr {
  key: "padding"
  value {
    s: "SAME"
  }
}
attr {
  key: "strides"
  value {
    list {
      i: 1
      i: 1
      i: 1
      i: 1
    }
  }
}
attr {
  key: "use_cudnn_on_gpu"
  value {
    b: true
  }
}

How do I access all these values and put them into Python lists? Specifically, how can I get the "strides" attribute and turn its four 1s into the Python list [1, 1, 1, 1]?

jkschin

1 Answer


TLDR: The code below is what you might want to use:

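import tensorflow as tf

for n in tf.get_default_graph().as_graph_def().node:
    # Conv/pool nodes store strides as an AttrValue whose "list" member holds ints
    if 'strides' in n.attr:
        print(n.name, [int(a) for a in n.attr['strides'].list.i])
    # Placeholder-like nodes store a TensorShapeProto; each dim carries a size
    if 'shape' in n.attr:
        print(n.name, [int(d.size) for d in n.attr['shape'].shape.dim])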
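The loop above relies on TF1's global default graph. As a self-contained sketch, assuming a TensorFlow version that exposes the graph API as tf.compat.v1 (1.13 or later, including 2.x), the same loop can be run over a tiny hypothetical graph; the node names x and conv and all shapes are made up for illustration:

```python
import tensorflow as tf

tf1 = tf.compat.v1

# A tiny hypothetical graph: one placeholder feeding one strided convolution.
graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 28, 28, 1], name="x")
    kernel = tf.zeros([3, 3, 1, 8])
    tf.nn.conv2d(x, kernel, strides=[1, 2, 2, 1], padding="SAME", name="conv")

strides, shapes = {}, {}
for n in graph.as_graph_def().node:
    if "strides" in n.attr:  # e.g. Conv2D, MaxPool
        strides[n.name] = [int(a) for a in n.attr["strides"].list.i]
    if "shape" in n.attr:  # e.g. Placeholder
        shapes[n.name] = [int(d.size) for d in n.attr["shape"].shape.dim]

print(strides)  # {'conv': [1, 2, 2, 1]}
print(shapes)   # the unknown batch dimension is stored as -1
```

Collecting into dicts keyed by node name, rather than printing, makes the results easy to reuse in further Python code.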

The trick is to understand what protocol buffers (protobufs) are, since a GraphDef and each of its nodes are protobuf messages. Let's go through the tutorial mentioned above.

First of all, there's a statement:

for node in graph_def.node:

Each node is a NodeDef object, defined in tensorflow/core/framework/node_def.proto. These are the fundamental building blocks of TensorFlow graphs, with each one defining a single operation along with its input connections. The tutorial then lists the members of a NodeDef and what they mean.

Note the following in node_def.proto:

  • It imports attr_value.proto.
  • It declares fields such as name, op, input, device and attr. Note that input is marked with the repeated keyword; we can ignore this for now.

In Python, the generated NodeDef class exposes these fields as ordinary attributes, so we can read node.name, node.op, node.input, node.device, node.attr, etc.
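For example, the node printed in the question can be rebuilt from its text form and its fields read back. This is a sketch assuming TensorFlow and the protobuf runtime are installed: google.protobuf.text_format.Parse fills a message from text like the one in the question, and the NodeDef class comes from TensorFlow's generated node_def_pb2 module.

```python
from google.protobuf import text_format
from tensorflow.core.framework import node_def_pb2

# Abridged text of the node printed in the question (attrs omitted).
NODE_TEXT = """
name: "conv2d_2/convolution"
op: "Conv2D"
input: "max_pooling2d/MaxPool"
input: "conv2d_1/kernel/read"
device: "/device:GPU:0"
"""

node = text_format.Parse(NODE_TEXT, node_def_pb2.NodeDef())

# Singular fields read like ordinary Python attributes.
print(node.name, node.op, node.device)

# input is repeated, so it behaves like a list of strings.
inputs = list(node.input)
print(inputs)  # ['max_pooling2d/MaxPool', 'conv2d_1/kernel/read']
```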

What we want to access next is the contents of node.attr. The tutorial describes it as follows:

This is a key/value store holding all the attributes of a node. These are the permanent properties of nodes, things that don't change at runtime such as the size of filters for convolutions, or the values of constant ops. Because there can be so many different types of attribute values, from strings, to ints, to arrays of tensor values, there's a separate protobuf file defining the data structure that holds them, in tensorflow/core/framework/attr_value.proto.

Each attribute has a unique name string, and the expected attributes are listed when the operation is defined. If an attribute isn't present in a node, but it has a default listed in the operation definition, that default is used when the graph is created.

You can access all of these members by calling node.name, node.op, etc. in Python. The list of nodes stored in the GraphDef is a full definition of the model architecture.
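To illustrate the "values of constant ops" mentioned in the quote, here is a sketch that reads a Const node's value attr back into a NumPy array. It assumes a TensorFlow version where both tf.compat.v1 and tf.make_ndarray exist (for instance 2.x); the constant and its name c are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical graph containing a single Const op named "c".
graph = tf1.Graph()
with graph.as_default():
    tf.constant([[1.0, 2.0], [3.0, 4.0]], name="c")

node = next(n for n in graph.as_graph_def().node if n.name == "c")
print(node.op)                                 # Const
print(node.attr["value"].WhichOneof("value"))  # tensor

# The "tensor" member is a TensorProto; make_ndarray decodes it.
array = tf.make_ndarray(node.attr["value"].tensor)
print(array.tolist())  # [[1.0, 2.0], [3.0, 4.0]]
```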

Since attr is a key-value store, we can call n.attr.keys() to see which attributes a node has, and access the strides with n.attr['strides'] if that key is present. When we print it, we get the following:
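The attr map behaves much like a Python dict. A minimal sketch, again assuming TensorFlow's generated protobuf modules are importable, using two of the attrs from the question's node:

```python
from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2, node_def_pb2

# Two attrs taken from the question's node, in protobuf text format.
node = text_format.Parse("""
name: "conv2d_2/convolution"
op: "Conv2D"
attr { key: "padding" value { s: "SAME" } }
attr { key: "strides" value { list { i: 1 i: 1 i: 1 i: 1 } } }
""", node_def_pb2.NodeDef())

keys = sorted(node.attr.keys())
print(keys)  # ['padding', 'strides']

# Test membership with `in` before indexing: on this message-valued map,
# indexing a missing key silently inserts an empty AttrValue.
has_strides = "strides" in node.attr
has_shape = "shape" in node.attr

strides_attr = node.attr["strides"]  # an AttrValue message, not a list
print(isinstance(strides_attr, attr_value_pb2.AttrValue))  # True
```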

list {
  i: 1
  i: 2
  i: 2
  i: 1
}

This is where it starts to get confusing, because we might try list(n.attr['strides']) or something similar, which does not work: n.attr['strides'] is an AttrValue message, not a sequence. Looking at attr_value.proto explains what's going on. AttrValue keeps its content in a oneof named value, and here the member that is set is list, of type ListValue, so we access it as n.attr['strides'].list. When we print that, we get the following:
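Because every AttrValue sets exactly one member of that oneof, WhichOneof("value") reports which member to read. The sketch below uses it to turn all the attrs printed in the question into plain Python values. The helper attr_to_python is hypothetical: it handles the scalar kinds plus int, float, bool and string lists, and leaves shapes, tensors and functions as protobuf messages.

```python
from google.protobuf import text_format
from tensorflow.core.framework import node_def_pb2, types_pb2

# The attrs of the node printed in the question, in protobuf text format.
node = text_format.Parse("""
attr { key: "T" value { type: DT_FLOAT } }
attr { key: "data_format" value { s: "NHWC" } }
attr { key: "padding" value { s: "SAME" } }
attr { key: "strides" value { list { i: 1 i: 1 i: 1 i: 1 } } }
attr { key: "use_cudnn_on_gpu" value { b: true } }
""", node_def_pb2.NodeDef())


def attr_to_python(value):
    """Hypothetical helper: convert one AttrValue into a plain Python value."""
    kind = value.WhichOneof("value")  # name of the oneof member that is set
    if kind == "list":
        lst = value.list
        if len(lst.s):
            return [s.decode() for s in lst.s]  # string lists hold bytes
        for field in ("i", "f", "b"):
            items = getattr(lst, field)
            if len(items):
                return list(items)
        return []  # empty, or a list kind this sketch does not handle
    if kind == "s":
        return value.s.decode()  # bytes -> str
    if kind == "type":
        return types_pb2.DataType.Name(value.type)  # enum number -> "DT_FLOAT"
    if kind is None:
        return None  # no member set
    return getattr(value, kind)  # i, f, b; shape/tensor/func stay messages


attrs = {key: attr_to_python(val) for key, val in node.attr.items()}
print(attrs)
```

Dispatching on WhichOneof keeps the helper independent of any particular op; supporting another kind only needs one more branch.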

i: 1
i: 1
i: 1
i: 1

We might next try [a for a in n.attr['strides'].list] or [a.i for a in n.attr['strides'].list], but neither works, because ListValue is itself a message rather than a list. This is where the repeated keyword matters: ListValue declares "repeated int64 i", meaning i is a list of int64 values, so the values have to be accessed through the i field. Doing [int(a) for a in n.attr['strides'].list.i] then gives us what we want, a Python list that we can use:
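The difference is easy to check directly. In this sketch (again assuming TensorFlow's generated attr_value_pb2 module and the protobuf runtime), the strides from the printed example above are parsed into an AttrValue; iterating the ListValue message fails, while the repeated i field supports len(), indexing and list():

```python
from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2

# The strides AttrValue printed earlier, rebuilt from its text form.
strides = text_format.Parse("list { i: 1 i: 2 i: 2 i: 1 }",
                            attr_value_pb2.AttrValue())

try:
    [a for a in strides.list]  # ListValue is a message, not a sequence
    list_iterable = True
except TypeError as err:
    list_iterable = False
    print("ListValue is not iterable:", err)

values = strides.list.i      # repeated int64 field: a list-like container
print(len(values), values[1])
as_list = list(values)       # elements are already Python ints
print(as_list)               # [1, 2, 2, 1]
```

Since the elements are already ints, list(n.attr['strides'].list.i) is equivalent to the int(a) comprehension.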

[1, 1, 1, 1]
jkschin