
I load a graph and want to access the weights defined in the graph as h1, h2, h3.

I can easily do this by hand for every weight tensor h by doing:

import tensorflow as tf

sess = tf.Session()
graph = tf.get_default_graph()
h1 = sess.graph.get_tensor_by_name("h1:0")
h2 = sess.graph.get_tensor_by_name("h2:0")

I don't like this approach since it gets ugly for a large graph. I would prefer something like a loop over all weight tensors that puts them into a list.

I did find two other questions (here and here) on Stack Overflow but they did not help me with this problem.

I tried the following approach, which has two problems:

num_weight_tensors = 3
weights = []
for w in range(num_weight_tensors):
    # The tensor name is hard-coded here, so this always fetches "h1:0"
    weights.append(sess.graph.get_tensor_by_name("h1:0"))
print(weights)

First problem: I have to define the number of weight tensors in the graph, which makes the code inflexible. Second problem: the argument of get_tensor_by_name() is static.

Is there a way to just get all tensors and put them into a list?

jodag
Gilfoyle

2 Answers


You can call tf.trainable_variables() if you are only concerned with the weights you can optimize. It returns a list of all variables whose trainable parameter is set to True.

import tensorflow as tf

tf.reset_default_graph()

# These can be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))

# These cannot be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="n{}".format(i), trainable=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = tf.get_default_graph()
    for t_var in tf.trainable_variables():
        print(t_var)

Prints:

<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>

On the other hand, tf.global_variables() returns a list of all variables:

for g_var in tf.global_variables():
    print(g_var)

Prints:

<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n4:0' shape=(32, 32) dtype=float32_ref>

UPDATE

To have more control over which Variables you'd like to retrieve, there are several ways to filter them. One way is what openmark suggested: filtering them based on the variable scope prefix.

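As a rough sketch of that scope-based approach (the variable scope name "weights" below is just an assumption for illustration, it does not come from the question):

import tensorflow as tf

tf.reset_default_graph()

# Suppose the weights were created under a variable scope called "weights"
with tf.variable_scope("weights"):
    for i in range(3):
        tf.get_variable("h{}".format(i), shape=[32, 32])

# Collect only the variables whose names carry the "weights/" scope prefix
scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="weights")
print(scoped_vars)
# [<tf.Variable 'weights/h0:0' ...>, <tf.Variable 'weights/h1:0' ...>, <tf.Variable 'weights/h2:0' ...>]
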
However, if this is not enough, for example if you wish to access several groups of variables simultaneously, there are other ways. You could simply filter them by name, that is:

for g_var in tf.global_variables():
    if g_var.name.startswith('h'):
        print(g_var)

However, you have to be aware of the naming conventions of TensorFlow Variables, for example the :0 suffix, variable scope prefixes, and so on.
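
For instance, a rough way to normalize those names before matching (this sketch assumes full names of the form 'some_scope/h1:0'; adjust it to your own graph):

for g_var in tf.global_variables():
    # Full names look like 'some_scope/h1:0'; keep only the bare variable name
    bare_name = g_var.name.split('/')[-1].split(':')[0]
    if bare_name.startswith('h'):
        print(bare_name, g_var)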

A second, less involved way is to create your own collections. For example, if in one place I am interested in variables whose name ends with a number divisible by 2, and somewhere else in the code in variables whose name ends with a number divisible by 4, I could do something like this:

# These can be optimized
for i in range(5):
    h_var = tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))
    if i % 2 == 0:
        tf.add_to_collection('vars_divisible_by_2', h_var)
    if i % 4 == 0:
        tf.add_to_collection('vars_divisible_by_4', h_var)

and then I can simply call the tf.get_collection() function:

tf.get_collection('vars_divisible_by_2')
[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]

or

tf.get_collection('vars_divisible_by_4')
[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]
Aechlys
  • When I get the list using `tf.global_variables()` or `tf.trainable_variables()`, how do I filter the list for variables which start with `h`? Finally I want a list like `all_weights = [h1, h2, h3, ... ,hn]` where `h1` - `hn` are of type `tf.Variable`. – Gilfoyle May 04 '18 at 10:18
  • I've updated my original answer. You can either use openmark's solution with variable scopes; you can filter them by name yourself (which can, however, get quite involved very fast); or you can simply create your own collections. – Aechlys May 04 '18 at 13:20

You can try tf.get_collection():

tf.get_collection(
    key,
    scope=None
)

It returns a list of the items in the collection specified by key and scope. key is a key from the standard graph collections tf.GraphKeys; for instance, tf.GraphKeys.TRAINABLE_VARIABLES specifies the subset of variables that are trained by an optimizer, while tf.GraphKeys.GLOBAL_VARIABLES specifies the list of all global variables, including non-trainable ones. Check the link above for the full list of available keys. You can also specify the scope parameter to filter the resulting list so that only items from a specific name scope are returned. Here is a small example:

with tf.name_scope("aaa"):
    aaa1 = tf.Variable(tf.zeros(shape=(1,2,3)), name="aaa1")


with tf.name_scope("bbb"):
    bbb1 = tf.Variable(tf.zeros(shape=(4,5,6)), name="bbb1", trainable=False)

for item in tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES):
    print(item)
# >>> <tf.Variable 'aaa/aaa1:0' shape=(1, 2, 3) dtype=float32_ref>

for item in tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
# >>> <tf.Variable 'aaa/aaa1:0' shape=(1, 2, 3) dtype=float32_ref>
# >>> <tf.Variable 'bbb/bbb1:0' shape=(4, 5, 6) dtype=float32_ref>

for item in tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope="bbb"):
    print(item)
# >>> <tf.Variable 'bbb/bbb1:0' shape=(4, 5, 6) dtype=float32_ref>
abhuse