Note: this answer and the OP's answer are complementary to each other; read the OP's answer first.
I've spent 4 hours today on this issue. This is one of the places where the ugliness of TensorFlow unfolds (and that's why you should use PyTorch if graph manipulation is your thing).
The crucial point here is: a `tf.Variable` is NOT a graph element (more on that here), but a wrapper around 3 ops: the `Assign` op, the `Read` op, and the `VariableV2` op, which is essentially a ref tensor (more on that here). So it is something you have to handle explicitly at the TensorFlow framework level, outside the graph.
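You can see those three building blocks directly on any existing variable. A minimal sketch, assuming TF 1.x with the default (non-resource) variables; the variable name `w` is made up:

import tensorflow as tf

v = tf.get_variable('w', shape=[2, 2])
print(v.op.type)           # 'VariableV2' -- the ref-tensor op
print(v.initializer.name)  # 'w/Assign'   -- the Assign op
print(v.value().name)      # 'w/read:0'   -- the output of the Read op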
If we look closely at the `graph_editor`'s code, especially the transform module, we can see that it operates only on the `tf.Graph`, not touching anything from the TensorFlow framework. So the `graph_editor.copy` (and similar) methods do not touch `tf.Variable` objects at all; they only copy the tensors and ops that are the building blocks of a `tf.Variable`.
Okay, then how do we solve this problem?
Suppose you have the following variable:
var = tf.trainable_variables()[0]
print(var.to_proto())
# variable_name: "dense_1/kernel:0"
# initializer_name: "dense_1/kernel/Assign"
# snapshot_name: "dense_1/kernel/read:0"
# initial_value_name: "dense_1/random_uniform:0"
# trainable: true
You know that after `graph_editor.copy(...)`, your `dense_1` name scope is now `dense_1b`. Then all you need is to use `info.transformed(...)` to get the corresponding ops and tensors, and do the following:
from tensorflow.core.framework import variable_pb2
var_def = variable_pb2.VariableDef()
var_def.variable_name = 'dense_1b/kernel:0'              # the copied VariableV2 output
var_def.initializer_name = 'dense_1b/kernel/Assign'      # the copied Assign op
var_def.snapshot_name = 'dense_1b/kernel/read:0'         # the copied Read output
var_def.initial_value_name = 'dense_1/random_uniform:0'  # the original initial-value tensor
var_def.trainable = True
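For completeness, here is a hedged sketch of how the copy and the `info.transformed(...)` lookups could replace the hard-coded names above (assuming `tf.contrib.graph_editor` imported as `ge`, and `var` from the earlier snippet):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

# Copy everything under 'dense_1' into 'dense_1b'; info maps originals to copies
sgv = ge.make_view_from_scope('dense_1', tf.get_default_graph())
_, info = ge.copy(sgv, dst_scope='dense_1b')

var_def.variable_name = info.transformed(var.op.outputs[0]).name   # 'dense_1b/kernel:0'
var_def.initializer_name = info.transformed(var.initializer).name  # 'dense_1b/kernel/Assign'
var_def.snapshot_name = info.transformed(var.value()).name         # 'dense_1b/kernel/read:0'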
Now, I want to emphasize the following part of the `tf.Variable` documentation:

`variable_def`: ... recreates the Variable object with its contents, referencing the variable's nodes in the graph, which must already exist. The graph is not changed.
So the `tf.Variable` constructor allows us to create a Variable wrapper on top of existing graph elements. That's exactly what we need:
cloned_var = tf.Variable(variable_def=var_def)  # wraps the copied ops; adds nothing to the graph
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, cloned_var)  # make it visible to initializers/savers
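As a quick sanity check (a hedged sketch using the TF 1.x session API), the clone initializes and reads entirely through the copied ops:

with tf.Session() as sess:
    sess.run(cloned_var.initializer)  # runs 'dense_1b/kernel/Assign'
    print(sess.run(cloned_var))       # fetches via 'dense_1b/kernel/read:0'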
Solved!
I kept this answer as simple and specific as possible to show the underlying mechanics of `tf.Variable`s. You can now easily implement the code for the more general case to create new variables automatically.
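For instance, something along these lines should work for a whole scope at once (an untested sketch; `info` is the object returned by `graph_editor.copy`, and `clone_scope_variables` is a name I made up):

import tensorflow as tf
from tensorflow.core.framework import variable_pb2

def clone_scope_variables(info, src_scope):
    # Rebuild a Variable wrapper for every trainable variable under src_scope,
    # pointing each VariableDef field at the copied graph elements.
    cloned = []
    for v in tf.trainable_variables():
        if not v.name.startswith(src_scope + '/'):
            continue
        var_def = variable_pb2.VariableDef()
        var_def.variable_name = info.transformed(v.op.outputs[0]).name
        var_def.initializer_name = info.transformed(v.initializer).name
        var_def.snapshot_name = info.transformed(v.value()).name
        var_def.trainable = True
        new_v = tf.Variable(variable_def=var_def)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_v)
        cloned.append(new_v)
    return cloned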
PS: I hate TensorFlow!