There is a way to convert constants back to trainable variables in TensorFlow, via the Graph Editor. However, you will need to specify which nodes to convert, as I'm not aware of a way to detect them automatically in a robust manner.
Here are the steps:
Step 1: Load frozen graph
We load our .pb file into a graph object.
import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')
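As a quick sanity check that the import worked (not part of the recipe, just a suggestion), you can count the operations in the loaded graph:

# A non-empty operation list means the protobuf was imported successfully.
print('Loaded graph with {} ops'.format(len(tf_graph.get_operations())))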
Step 2: Find constants that need conversion
Here are two ways to list the names of the nodes in your graph:

- Use this script to print them:

print([n.name for n in tf_graph.as_graph_def().node])

- Load your graph in Netron to see which tensors are storing the trainable weights.

The nodes you'll want to convert are likely named something along the lines of "Const". Oftentimes, it is safe to assume that all Const nodes were once variables.
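If you want a programmatic starting point, you can filter for ops of type "Const". Note this is only a heuristic: it will also pick up constants that were never variables (e.g. shape or axis arguments), so prune the result by hand or against what you see in Netron.

# Heuristic: list all Const ops as candidates for conversion. Not robust,
# since this also catches constants that were never variables.
candidate_names = [op.name for op in tf_graph.get_operations()
                   if op.type == 'Const']
print(candidate_names)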
Once you have these nodes identified, let's store their names into a list:
to_convert = [...]  # names of the Const nodes to convert
Step 3: Convert constants to variables
Run this code to convert the constants you specified. For each constant, it creates a corresponding variable, then uses the Graph Editor to unhook the constant from the graph and hook the variable on in its place.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:
    for name in to_convert:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create a TensorFlow variable initialized with the values of the
        # original constant. Use the constant's own dtype rather than
        # hard-coding float32, in case some constants are e.g. int32.
        var = tf.get_variable(name=var_name, dtype=tensor.dtype.base_dtype,
                              shape=var_shape,
                              initializer=tf.constant_initializer(tensor_as_numpy_array))
        # Keep track of the constant/variable name pairs for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we have added a bunch of tf.Variables to the graph, but
    # they're not connected to anything.
    # The magic: we use the Graph Editor to swap each Const node's outputs
    # with the outputs of its newly created Variable's read op.
    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
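To sanity-check the rewiring (optional, just something I'd suggest), you can confirm that nothing in the graph consumes the original constants anymore:

# Optional check: after the swap, the original Const outputs should have no
# remaining consumers; everything now reads from the variables instead.
for const_name, _ in const_var_name_pairs:
    const_out = g.get_operation_by_name(const_name).outputs[0]
    assert not const_out.consumers(), '{} still has consumers'.format(const_name)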
Step 4: Save result as .ckpt
Finally, we initialize the new variables (which populates them with the original constants' values, via the initializers we set above) and save a checkpoint:
with tf_graph.as_default():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)
And voilà! You should be done at this point :) I was able to get this working myself, and verified that the model weights are preserved; the only difference is that the graph is now trainable. Please let me know if there are any issues.