38

I created a trainable variable inside a scope. Later, I entered the same scope, called reuse_variables() on it, and used get_variable to retrieve the same variable. However, I cannot set the variable's trainable property to False. My get_variable line looks like:

weight_var = tf.get_variable('weights', trainable=False)

But the variable 'weights' is still in the output of tf.trainable_variables.

Can I set a shared variable's trainable flag to False by using get_variable?

The reason I want to do this is that I'm trying to reuse low-level filters pre-trained on VGG net in my model. I want to build the graph as before, retrieve the weights variable, assign the VGG filter values to it, and then keep it fixed during the subsequent training steps.

nbro
Wei Liu
  • The `var_list` argument in the [`minimize()`](https://www.tensorflow.org/versions/r0.8/api_docs/python/train.html#Optimizer) function is the standard place to specify training on only some variables. – user728291 May 20 '16 at 00:12

4 Answers

31

After looking at the documentation and the code, I was not able to find a way to remove a Variable from the TRAINABLE_VARIABLES collection.

Here is what happens:

  • The first time tf.get_variable('weights', trainable=True) is called, the variable is added to the list of TRAINABLE_VARIABLES.
  • The second time you call tf.get_variable('weights', trainable=False), you get the same variable, but the argument trainable=False has no effect because the variable is already present in the list of TRAINABLE_VARIABLES (and there is no way to remove it from there). The sketch below illustrates this.
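A minimal sketch of this behavior (assuming TensorFlow 1.x graph mode; the scope name and shape are illustrative):

import tensorflow as tf

with tf.variable_scope("conv1"):
    # First call: creates the variable and adds it to TRAINABLE_VARIABLES.
    w = tf.get_variable("weights", shape=[3, 3, 3, 64])

with tf.variable_scope("conv1", reuse=True):
    # Second call: returns the existing variable; trainable=False is ignored.
    w_again = tf.get_variable("weights", trainable=False)

print(w is w_again)                                   # True: same variable object
print(any(v is w for v in tf.trainable_variables()))  # True: still trainable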

First solution

When calling the minimize method of the optimizer (see doc.), you can pass a var_list=[...] argument with the variables you want to optimize.

For instance, if you want to freeze all the layers of VGG except the last two, you can pass the weights of the last two layers in var_list.
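A minimal sketch of this, assuming a loss tensor loss is already defined and that the last two layers' variables live under hypothetical scopes vgg/fc7 and vgg/fc8:

import tensorflow as tf

# Keep only the variables of the last two layers (scope names are hypothetical).
train_vars = [v for v in tf.trainable_variables()
              if v.name.startswith("vgg/fc7") or v.name.startswith("vgg/fc8")]

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# Only variables in var_list receive gradient updates; all others stay frozen.
train_op = optimizer.minimize(loss, var_list=train_vars)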

Second solution

You can use a tf.train.Saver() to save variables and restore them later (see this tutorial).

  • First you train your entire VGG model with all trainable variables. You save them in a checkpoint file by calling saver.save(sess, "/path/to/dir/model.ckpt").
  • Then (in another file) you train the second version with non-trainable variables. You load the variables previously stored with saver.restore(sess, "/path/to/dir/model.ckpt").

Optionally, you can decide to save only some of the variables in your checkpoint file. See the doc for more info.
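A condensed sketch of the workflow (in practice the two phases live in separate scripts, each building its own graph and Saver; the checkpoint path is the one from above):

import tensorflow as tf

# --- Phase 1: train the full model, then save a checkpoint ---
saver = tf.train.Saver()  # by default saves all variables in the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... run your training steps here ...
    saver.save(sess, "/path/to/dir/model.ckpt")

# --- Phase 2: rebuild the graph, this time creating the frozen
# layers with trainable=False, then restore the saved weights ---
with tf.Session() as sess:
    saver.restore(sess, "/path/to/dir/model.ckpt")
    # ... continue training; only the remaining trainable variables update ...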

Olivier Moindrot
  • Thanks. I tried the same thing to find out whether I could remove a variable from the `TRAINABLE_VARIABLES` collection, but could not. Looks like defining a list of trainable variables is the best option for me. – Wei Liu May 20 '16 at 18:13
  • 4
    Hold on, I just found that `get_collection_ref()` returns a reference to the trainable_variables collection, which I should be able to modify to remove some entries. I haven't tested it yet. Anyway, that's less important: I can always filter the trainable variables I get from `get_collection()` and send the result to the optimizer. – Wei Liu May 20 '16 at 20:23
  • Does changing the trainable collection via `get_collection_ref()` have any side effects? – shellhue Sep 10 '18 at 15:43
  • 1
    @Olivier It is **NOT** true that you cannot remove a trainable variable from the list of trainables. You can do `trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)` to get a reference for the collection of trainable variables, which is a python list, and then use `pop` with the right index to remove the variable from there. I tested it and it prevents the variable from being trained. – Elisio Quintino Oct 23 '18 at 07:41
  • I'm adding an answer with the method to remove variables from the TRAINABLE_VARIABLES collection – Elisio Quintino Oct 23 '18 at 12:11
  • remove from trainable_variables: just mess with `tf.get_default_graph()._collections['trainable_variables']`. btw it's useless by the time apply_gradient() is called – skywalkerytx Mar 12 '19 at 13:57
14

When you want to train or optimize only certain layers of a pre-trained network, this is what you need to know.

TensorFlow's minimize method takes an optional argument var_list, a list of variables to be adjusted through back-propagation.

If you don't specify var_list, any TF variable in the graph could be adjusted by the optimizer. When you specify some variables in var_list, TF holds all other variables constant.

Here's an example of a script which jonbruner and his collaborator have used.

# Gather every variable currently marked as trainable.
tvars = tf.trainable_variables()
# Keep only the generator variables (their names contain 'g_').
g_vars = [var for var in tvars if 'g_' in var.name]
# Run Adam only on g_vars; g_loss is the generator loss defined earlier.
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

This finds all the variables they defined earlier that have "g_" in their names, puts them into a list, and runs the Adam optimizer on them.

You can find the related answers here on Quora

rocksyne
  • Sorry I know this is an old thread but what does g_loss refer to? Trying to replicate your solution – A_Murphy Apr 01 '22 at 08:44
7

In order to remove a variable from the list of trainable variables, you can first access the collection through trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES). Here, trainable_collection holds a reference to the collection of trainable variables. If you pop elements from this list, for example with trainable_collection.pop(0), you remove the corresponding variable from the trainable variables, and that variable will not be trained.

Although this works with pop, I am still struggling to find a way to use remove with the right argument, so that we don't depend on the index of the variable.
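For reference, a quick sketch of the index-based pop approach described above (the index 0 is arbitrary and assumes at least one trainable variable exists):

import tensorflow as tf

# get_collection_ref returns the actual list backing the collection,
# so mutating it changes what tf.trainable_variables() reports.
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
popped = trainable_collection.pop(0)  # this variable will no longer be trained
print(popped.name)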

EDIT: Given that you have the names of the variables in the graph (you can obtain them by inspecting the graph protobuf or, more easily, using TensorBoard), you can loop through the list of trainable variables and remove the variables from the trainable collection. For example: say that I want the variables named "batch_normalization/gamma:0" and "batch_normalization/beta:0" NOT to be trained, but they are already added to the TRAINABLE_VARIABLES collection. What I can do is:

# Get a reference to the list that backs the trainable-variables collection.
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    # Use the variable's 'name' attribute to identify it.
    if vari.name == "batch_normalization/gamma:0" or vari.name == "batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

This will successfully remove the two variables from the collection, and they will not be trained anymore.

Elisio Quintino
  • 1
    I tried your method, but unfortunately it does NOT work, at least for TF version 1.6.0. To be precise, it "wins the battle but loses the war"! It indeed nukes the trainable variables from being listed as trainable [as shown by calling `tf.trainable_variables()`]... but the `minimize()` method carries on as if nothing had happened - i.e. those variables are still being trained! My guess is that when minimize() is first called, it takes a "snapshot" of what variables to train, and then it keeps using them regardless of later changes in the graph collection GraphKeys.TRAINABLE_VARIABLES – Julian - BrainAnnex.org Jan 02 '19 at 02:53
  • Thank you for noting this. I worked always with versions above 1.10. Do you have some code exemplifying this behavior so I can check with later versions and also try to understand why it doesn't work? – Elisio Quintino Jan 25 '19 at 11:12
0

You can use tf.get_collection_ref to get a reference to the collection itself, rather than tf.get_collection, which returns a copy.
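A minimal sketch of the difference (assuming TF 1.x; the variable is illustrative):

import tensorflow as tf

v = tf.Variable(1.0, name="w")  # creates one trainable variable

copied = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)          # a copy of the list
referenced = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)  # the live list

copied.remove(v)      # no effect on the graph's collection
print(len(tf.trainable_variables()))  # 1

referenced.remove(v)  # actually removes v from the trainable variables
print(len(tf.trainable_variables()))  # 0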

Yuki