This question is related to this question, which provides a solution that works in TensorFlow 1.15 but no longer works in TF2.
I'm taking part of the code from that question and adapting it slightly (I removed the frozen model's multiple inputs and, with them, the need for `nest`).
Note: I'm splitting the code into blocks, but they're meant to be run as one file (i.e., I won't repeat the imports in each block).
First, we generate a frozen graph to use as dummy test network:
```python
import numpy as np
import tensorflow.compat.v1 as tf

def dump_model():
    with tf.Graph().as_default() as gf:
        x = tf.placeholder(tf.float32, shape=(None, 123), name='x')
        c = tf.constant(100, dtype=tf.float32, name='C')
        y = tf.multiply(x, c, name='y')
        z = tf.add(y, x, name='z')
        with tf.gfile.GFile("tmp_net.pb", "wb") as f:
            raw = gf.as_graph_def().SerializeToString()
            print(type(raw), len(raw))
            f.write(raw)

dump_model()
```
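To see what the frozen file actually contains (and therefore where the tensor names `x:0`, `y:0` and `z:0` used later come from), the serialized GraphDef can be parsed back and its nodes listed. This is a minimal sketch, not part of the reproduction: it rebuilds the same toy graph in memory instead of reading `tmp_net.pb` so that it runs on its own, and `build_graph_def` is a hypothetical helper name.

```python
import tensorflow.compat.v1 as tf


def build_graph_def():
    """Hypothetical helper: rebuild the toy graph from dump_model() in memory."""
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, shape=(None, 123), name='x')
        c = tf.constant(100, dtype=tf.float32, name='C')
        y = tf.multiply(x, c, name='y')
        tf.add(y, x, name='z')
    return g.as_graph_def()


# Serialize and parse back, mirroring what writing and reading tmp_net.pb does.
raw = build_graph_def().SerializeToString()
parsed = tf.GraphDef()
parsed.ParseFromString(raw)

# Each NodeDef stores the node name, its op type and the names of its inputs.
# Tensor names such as 'y:0' are "<node name>:<output index>".
for node in parsed.node:
    print(node.name, node.op, list(node.input))
```

The `Placeholder` node `x` is what the Keras `InputLayer` later wraps, and `y`/`z` are the nodes whose first outputs the `Lambda` layers return.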
Then, we load the frozen model and wrap it in a Keras Model:
```python
persisted_sess = tf.Session()
with tf.Session().as_default() as session:
    with tf.gfile.GFile("./tmp_net.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Import into the graph owned by persisted_sess; a bare
        # `persisted_sess.graph.as_default()` call has no effect without `with`.
        with persisted_sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')

# List every imported op together with its first output tensor.
print(persisted_sess.graph.get_name_scope())
for i, op in enumerate(persisted_sess.graph.get_operations()):
    tensor = persisted_sess.graph.get_tensor_by_name(op.name + ':0')
    print(i, '\t', op.name, op.type, tensor)

x_tensor = persisted_sess.graph.get_tensor_by_name('x:0')
y_tensor = persisted_sess.graph.get_tensor_by_name('y:0')
z_tensor = persisted_sess.graph.get_tensor_by_name('z:0')

from tensorflow.compat.v1.keras.layers import Lambda, InputLayer
from tensorflow.compat.v1.keras import Model
from tensorflow.python.keras.utils import layer_utils

# Use the imported placeholder x:0 as the Keras input, and have each Lambda
# layer return one of the already-existing output tensors (y:0, z:0).
input_x = InputLayer(name='x', input_tensor=x_tensor)
input_x.is_placeholder = True
output_y = Lambda(lambda x: y_tensor, name='output_y')(input_x.output)
output_z = Lambda(lambda x_b: z_tensor, name='output_z')(input_x.output)

base_model_inputs = layer_utils.get_source_inputs(input_x.output)
base_model = Model(base_model_inputs, [output_y, output_z])
```
Finally, we run the model on some random data and verify that it runs without errors:
```python
y_out, z_out = base_model.predict(np.ones((3, 123), dtype=np.float32))
print((y_out.shape, z_out.shape))
```
In TensorFlow 1.15.3, the output of the third block is `((3, 123), (3, 123))`. However, if I run the same code in TensorFlow 2.1.0, the first two blocks run without a problem, but the third one fails with:
```
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: y:0
```
The error seems related to TensorFlow's automatic "compilation" and optimization of functions, but I don't know how to interpret it, what its source is, or how to resolve it.
What is the correct way to wrap the frozen model in TensorFlow 2?
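For reference, the TensorFlow 2 migration guide describes a different way to execute a frozen graph: import the GraphDef inside `tf.compat.v1.wrap_function` and prune the result into a concrete function from the input tensor to the output tensors, so no `Session` is involved and no tensor has to cross from one graph into another. The sketch below applies that pattern to the toy graph. It assumes TF 2.x, rebuilds the graph in memory rather than reading `tmp_net.pb`, and `build_graph_def` / `wrap_frozen_graph` are hypothetical helper names. It produces a callable concrete function, not a Keras `Model`, so it does not by itself cover the Keras-wrapping part of the question.

```python
import numpy as np
import tensorflow as tf


def build_graph_def():
    """Hypothetical helper: the same toy graph as dump_model(), built in memory."""
    with tf.Graph().as_default() as g:
        x = tf.compat.v1.placeholder(tf.float32, shape=(None, 123), name='x')
        c = tf.constant(100, dtype=tf.float32, name='C')
        y = tf.multiply(x, c, name='y')
        tf.add(y, x, name='z')
    return g.as_graph_def()


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Hypothetical helper: turn a frozen GraphDef into a TF2 concrete function."""
    def _imports_graph_def():
        # Runs while wrap_function traces its own graph, so the imported
        # nodes live inside that function graph.
        tf.compat.v1.import_graph_def(graph_def, name='')

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Keep only the subgraph that maps the given input tensor(s) to the outputs.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


frozen_func = wrap_frozen_graph(build_graph_def(),
                                inputs='x:0', outputs=['y:0', 'z:0'])

# The pruned function takes the value fed to x:0 and returns [y, z] as eager tensors.
y_out, z_out = frozen_func(tf.ones((3, 123), dtype=tf.float32))
print(y_out.shape, z_out.shape)
```

Called on a batch of ones, every `y` value is 100.0 and every `z` value is 101.0, since the graph computes `y = 100 * x` and `z = y + x`.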