
When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value "Variable_23:0". When I try to select this variable using tf.get_variable("Variable_23"), a new variable called "Variable_23_1:0" is created instead. How do I correctly select "Variable_23" instead of creating a new one?

What I want to do is select the variable by name, and reinitialize it so I can finetune weights.

mrry
user3528623

4 Answers


The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:

>>> import tensorflow as tf
>>> with tf.variable_scope("foo"):
...   bar1 = tf.get_variable("bar", (2,3)) # create
... 
>>> with tf.variable_scope("foo", reuse=True):
...   bar2 = tf.get_variable("bar")  # reuse
... 

>>> with tf.variable_scope("", reuse=True): # root variable scope
...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
... 
>>> (bar1 is bar2) and (bar2 is bar3)
True

If you did not create the variable using tf.get_variable(), you have a couple of options. First, you can use tf.global_variables() (as @mrry suggests):

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True

Or you can use tf.get_collection() like so:

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
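A caveat on the tf.get_collection() approach: the scope argument is matched with re.match() against each item's name, so it acts as a prefix filter. scope="bar" would also return variables named "barbell:0" or "bar/kernel:0", and [0] could pick the wrong one. The sketch below mimics that filtering in plain Python so it runs without TensorFlow; the SimpleNamespace objects are hypothetical stand-ins for tf.Variable instances, and filter_by_scope is an illustrative helper, not a TensorFlow function. Anchoring the pattern (for example scope="bar:0$") restricts it to an exact name.

```python
import re
from types import SimpleNamespace


def filter_by_scope(items, scope):
    """Mimic the scope filter of tf.get_collection(): keep items whose
    `name` matches `scope` via re.match (anchored at the start only)."""
    regex = re.compile(scope)
    return [item for item in items if regex.match(item.name)]


# Hypothetical stand-ins for tf.Variable objects; only `.name` matters here.
variables = [
    SimpleNamespace(name="barbell:0"),
    SimpleNamespace(name="bar:0"),
    SimpleNamespace(name="bar/kernel:0"),
]

# Unanchored pattern: a prefix match, so all three names qualify.
prefix_matches = filter_by_scope(variables, "bar")
# Anchored pattern: only the exact name "bar:0" qualifies.
exact_matches = filter_by_scope(variables, "bar:0$")

print([v.name for v in prefix_matches])
print([v.name for v in exact_matches])
```

With the unanchored pattern, [0] here would return "barbell:0", not "bar:0".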

Edit

You can also use get_tensor_by_name():

>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
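>>> bar1 is bar2
False

bar2 is not the tf.Variable object itself but the tf.Tensor produced by the variable's op (named "bar:0"); with the default TF 1.x (non-resource) variables it evaluates to the same value as bar1.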

Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0. If the operation has multiple outputs, they have the same name as the operation plus :0, :1, :2, and so on.
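To make the naming rule concrete, here is a small plain-Python helper (hypothetical, not part of TensorFlow) that splits a tensor name into its operation name and output index. This is also why, for the variable above, var.name is "bar:0" while var.op.name is "bar".

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as "Variable_23:0" into (op_name, output_index).

    rpartition() splits on the last ":", so scoped op names such as
    "foo/bar" are kept whole.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"not a tensor name: {tensor_name!r}")
    return op_name, int(index)


print(split_tensor_name("Variable_23:0"))  # ('Variable_23', 0)
print(split_tensor_name("foo/bar:1"))      # ('foo/bar', 1)
```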

BugKiller
MiniQuark

The easiest way to get a variable by name is to search for it in the tf.global_variables() collection:

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]

This works well for ad hoc reuse of existing variables. A more structured approach, for when you want to share variables between multiple parts of a model, is covered in the Sharing Variables tutorial.
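Since the question's goal is to reinitialize the variable before fine-tuning, here is a minimal sketch of that step. It assumes TensorFlow 2.x running the 1.x graph API through tf.compat.v1 (with TensorFlow 1.x you would use import tensorflow as tf and drop the disable_eager_execution() call). It finds the variable by name with the list comprehension above and then runs only that variable's initializer op, which resets it to its initial value and leaves every other variable alone. The shapes and values are illustrative only.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, matching the TF 1.x API used in this thread

graph = tf.Graph()
with graph.as_default():
    # Unnamed variables get automatic names ("Variable:0", "Variable_1:0", ...).
    weights = tf.Variable(tf.ones([2, 2]))
    other = tf.Variable(tf.zeros([2]))

    # Look the variable up by name, as in the answer above
    # (in the question this name would be "Variable_23:0").
    target = [v for v in tf.global_variables() if v.name == weights.name][0]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Simulate training having changed both variables.
        sess.run([target.assign(tf.fill([2, 2], 5.0)), other.assign([3.0, 3.0])])
        # Re-run only the target's initializer: it goes back to all ones,
        # while `other` keeps its "trained" value.
        sess.run(target.initializer)
        reset_value, other_value = sess.run([target, other])
```

To reset several variables at once, tf.variables_initializer(list_of_vars) groups their initializers into a single op.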

BlueSun
mrry
  • can you provide an example with the sharing variable thing? It keeps asking me to reuse and I understand what that means, but I can't get tensorflow to work. – Charlie Parker Jul 25 '16 at 16:51
  • UPDATE:WARNING:tensorflow:From :1 in .: all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02. Instructions for updating: Please use tf.global_variables instead. – MartianMartian Feb 16 '17 at 06:41
  • Hi, thanks. I have a different situation. Could you please tell me whether I can use `get_tensor_by_name` or `get variable by name` to get something defined by `tf.layers.dense`, e.g., `means` in this code sample [here](https://gist.github.com/Erichliu00/1ce345e548b31cf1f2a6efed34ba9dec). – ytutow Aug 26 '17 at 02:45
  • Is there a better way that uses a key instead of iterating over all created ops? – MeadowMuffins Jul 07 '19 at 11:20
  • Link is dead. Too old? – Thomas Jul 08 '20 at 07:19

If you want to read a stored variable from a saved model checkpoint, use tf.train.load_variable("model_folder_name", "Variable name"). It returns the variable's value as a NumPy array rather than a tf.Variable.
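tf.train.list_variables() is handy for finding the right name first: checkpoint keys are the variables' op names without the ":0" suffix (e.g. "Variable_23", not "Variable_23:0"). The sketch below assumes TensorFlow 2.x with the tf.compat.v1 graph API; it writes a throwaway checkpoint to a temporary directory only so there is something to read, and the "model.ckpt" prefix is an arbitrary choice.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, matching the TF 1.x API used in this thread

ckpt_dir = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    weights = tf.Variable([[1.0, 2.0], [3.0, 4.0]])  # auto-named "Variable:0"
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint path prefix it wrote.
        ckpt_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# (name, shape) pairs stored in the checkpoint; names carry no ":0" suffix.
stored = dict(tf.train.list_variables(ckpt_path))

# Read one variable's value as a NumPy array, looked up by its op name.
value = tf.train.load_variable(ckpt_path, weights.op.name)
```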


Building on @mrry's answer, I think it is better to define and use the following function, because there are also local variables and other variables that are not in the global-variables collection (they live in different collections):

def get_var_by_name(query_name, var_list):
    """
    Get a variable from var_list by its full name (e.g. 'accuracy/total:0').
    Returns None if no variable matches.

    e.g.
    local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
    the_var = get_var_by_name('accuracy/total:0', local_vars)
    """
    target_var = None
    for var in var_list:
        if var.name == query_name:
            target_var = var
            break
    return target_var
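The same lookup can be written more compactly with next() and a default of None. This is a hypothetical variant, not from the original answer; to keep it runnable without TensorFlow it uses SimpleNamespace stand-ins that only carry a .name attribute, whereas in practice you would pass the list returned by tf.get_collection(...) or tf.global_variables().

```python
from types import SimpleNamespace


def get_var_by_name(query_name, var_list):
    """Return the first item in var_list whose .name equals query_name, or None."""
    return next((var for var in var_list if var.name == query_name), None)


# Hypothetical stand-ins for the contents of tf.GraphKeys.LOCAL_VARIABLES.
local_vars = [
    SimpleNamespace(name="accuracy/count:0"),
    SimpleNamespace(name="accuracy/total:0"),
]

total = get_var_by_name("accuracy/total:0", local_vars)   # the second stand-in
missing = get_var_by_name("accuracy/mean:0", local_vars)  # None: no such name
```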
ChrisZZ