
I have trained a Keras model based on this repo.

After training, I save the model as checkpoint files like this:

sess = tf.keras.backend.get_session()
saver = tf.train.Saver()
saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))

Then I restore the graph from the checkpoint files and freeze it using the standard TensorFlow freeze_graph script (a rough sketch of the invocation is below). When I try to restore the frozen graph, I get the following error:

Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource
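
For reference, the freeze step is roughly the following; the paths and the output node name here are placeholders, not my exact values:

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(
    input_graph='checkpoint_files/graph.pbtxt',   # GraphDef written alongside the checkpoint
    input_saver='',
    input_binary=False,
    input_checkpoint='checkpoint_files/model.ckpt',
    output_node_names='predictions/Softmax',      # placeholder output node name
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph='frozen_model.pb',
    clear_devices=True,
    initializer_nodes='')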

How can I fix this issue?

Edit: My problem is related to this question. Unfortunately, I can't use the workaround.

Edit 2: I have opened an issue on GitHub and created a gist to reproduce the error: https://github.com/keras-team/keras/issues/11032


3 Answers


I just resolved the same issue. I pieced together a few answers (1, 2, 3) and realized that the problem originates from the batch norm layer's working state: training vs. inference. So, to resolve the issue, you just need to place one line before building or loading your model:

keras.backend.set_learning_phase(0)

Complete example, to export a model:

import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    with graph.as_default():
        # strip training-only ops, then bake the variables into constants
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0)  # this line is most important: call it before building the model

base_model = InceptionV3()

session = tf.keras.backend.get_session()

# input/output node names, in case you need them for the freeze_graph CLI
INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
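
Note that set_learning_phase(0) must be called before the model is built or loaded; once the batch-norm training conditionals are already in the graph, setting the phase afterwards no longer helps.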

To load the *.pb model:

from PIL import Image
import numpy as np
import tensorflow as tf

# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0  # note: inception_v3.preprocess_input would scale to [-1, 1]; /255 is what this example uses
im = im[None, ...]  # add batch dimension

# parse the frozen GraphDef from disk
graph_def = tf.GraphDef()

with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()

with graph.as_default():
    # import the frozen graph and grab the input/output ops by name
    net_inp, net_out = tf.import_graph_def(
        graph_def, return_elements=["input_1", "predictions/Softmax"]
    )
    with tf.Session(graph=graph) as sess:
        out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
        print(np.argmax(out))
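
Since tf.import_graph_def prefixes imported node names with import/ by default, the same input and output tensors can also be fetched with graph.get_tensor_by_name("import/input_1:0") and graph.get_tensor_by_name("import/predictions/Softmax:0").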
  • Thanks a lot, only solution that worked for me after reading tens of others. – Kerem T Apr 02 '19 at 14:38
  • The problem with `set_learning_phase(0)` is that it puts the batch norm layers into inference mode, and thus if you do any training they will not update as normal. – geometrikal Aug 08 '19 at 00:58
  • Don't forget to clear the session before setting the learning phase to 0: `keras.backend.clear_session()` – tsveti_iko Nov 18 '19 at 15:49

This is a bug with TensorFlow 1.1x and, as another answer stated, it is because of the internal batch norm learning vs. inference state. In TF 1.14.0 you actually get a cryptic error when trying to freeze a batch norm layer.

Using set_learning_phase(0) will put the batch norm layer (and probably others, like dropout) into inference mode, and thus the batch norm layer will not work during training, leading to reduced accuracy.

My solution is this (a complete end-to-end sketch of the steps follows the list):

  1. Create the model using a function (do not use K.set_learning_phase(0)):

def create_model():
    inputs = Input(...)
    ...
    return model

model = create_model()

  2. Train the model.
  3. Save the weights: model.save_weights("weights.h5")
  4. Clear the session (important so layer names are the same) and set the learning phase to 0:

K.clear_session()
K.set_learning_phase(0)

  5. Recreate the model and load the weights:

model = create_model()
model.load_weights("weights.h5")

  6. Freeze as before.
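
Putting it together, here is a minimal end-to-end sketch of these steps with a hypothetical toy model (the architecture, dummy data, and file names are placeholders, not from the original answer):

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, BatchNormalization
from tensorflow.keras.models import Model


def create_model():
    # any architecture with a batch norm layer reproduces the issue
    inputs = Input(shape=(8,))
    x = Dense(16, activation="relu")(inputs)
    x = BatchNormalization()(x)
    outputs = Dense(1, activation="sigmoid")(x)
    return Model(inputs, outputs)


# steps 1-3: build, train, save weights (dummy data for illustration only)
model = create_model()
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(np.random.rand(32, 8), np.random.randint(0, 2, (32, 1)), epochs=1, verbose=0)
model.save_weights("weights.h5")

# step 4: clear the session, then force inference mode
K.clear_session()
K.set_learning_phase(0)

# step 5: recreate the identical architecture and load the trained weights
model = create_model()
model.load_weights("weights.h5")

# step 6: freeze, as in the accepted answer
session = K.get_session()
frozen = tf.graph_util.convert_variables_to_constants(
    session, session.graph.as_graph_def(), [out.op.name for out in model.outputs])
tf.train.write_graph(frozen, ".", "frozen_model.pb", as_text=False)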

Thanks for pointing out the main issue! I found that keras.backend.set_learning_phase(0) does not always work, at least not in my case.

Another approach might be:

for l in keras_model.layers:
    l.trainable = False
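
For completeness, a minimal sketch of that workaround, assuming a model saved as an .h5 file (the file name is a placeholder) and the freeze_graph helper from the accepted answer:

import tensorflow as tf

tf.keras.backend.clear_session()
keras_model = tf.keras.models.load_model("model.h5")  # placeholder path

# mark every layer as non-trainable before freezing
for l in keras_model.layers:
    l.trainable = False

session = tf.keras.backend.get_session()
# then freeze as in the accepted answer, e.g.:
# freeze_graph(session.graph, session, [out.op.name for out in keras_model.outputs])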
