In order to remove a variable from the list of trainable variables, you can first access the collection through:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
Here, trainable_collection holds a reference to the collection of trainable variables. If you pop elements from this list, for example with trainable_collection.pop(0), the corresponding variable is removed from the trainable variables and will therefore no longer be trained.
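For illustration, here is a minimal sketch of the pop approach, assuming TF 1.x and two hypothetical variables w and b:

```python
import tensorflow as tf

# two hypothetical variables; both are trainable by default
w = tf.Variable(tf.zeros([3]), name="w")
b = tf.Variable(tf.zeros([3]), name="b")

# get a reference to the mutable collection of trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
print([v.name for v in trainable_collection])  # ['w:0', 'b:0']

# remove the first entry by index: 'w' will no longer be trained
trainable_collection.pop(0)
print([v.name for v in trainable_collection])  # ['b:0']
```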
Although this works with pop, I was still struggling to find a way to correctly use remove with the right argument, so that we do not depend on the index of the variables.
EDIT: Given that you have the names of the variables in the graph (you can obtain them by inspecting the graph protobuf or, more easily, by using TensorBoard), you can use them to loop through the list of trainable variables and remove the variables from the trainable collection.
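If you prefer to find the names programmatically, a quick sketch (again TF 1.x) that prints every trainable variable's name is:

```python
import tensorflow as tf

# print the name of every variable currently in the trainable collection;
# a quick alternative to inspecting the graph protobuf or TensorBoard
for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    print(v.name)
```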
Example: say that I want the variables with names "batch_normalization/gamma:0"
and "batch_normalization/beta:0"
NOT to be trained, even though they have already been added to the TRAINABLE_VARIABLES
collection. What I can do is:
```python
import tensorflow as tf

# get a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)

variables_to_remove = list()
for vari in trainable_collection:
    # use the 'name' attribute of the variable
    if vari.name == "batch_normalization/gamma:0" or vari.name == "batch_normalization/beta:0":
        variables_to_remove.append(vari)

for rem in variables_to_remove:
    trainable_collection.remove(rem)
```
This will successfully remove the two variables from the collection, and they will not be trained anymore.
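One caveat worth noting: the training op has to be built after editing the collection, because minimize() collects the trainable variables at construction time. A minimal sketch with two hypothetical variables, kept and frozen, assuming TF 1.x:

```python
import tensorflow as tf

# hypothetical variables: 'kept' stays trainable, 'frozen' is removed
kept = tf.Variable(1.0, name="kept")
frozen = tf.Variable(1.0, name="frozen")

trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
trainable_collection.remove(frozen)

loss = tf.square(kept) + tf.square(frozen)
# minimize() defaults to the (now edited) trainable collection,
# so only 'kept' receives gradient updates
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([kept, frozen]))  # 'frozen' keeps its initial value 1.0
```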