
I have a fine-tuned VGG model that I created with the tensorflow.keras functional API and saved with tf.contrib.saved_model.save_keras_model. The saved model has this structure: an assets folder containing a saved_model.json file, a saved_model.pb file, and a variables folder containing checkpoint, variables.data-00000-of-00001 and variables.index.

I can easily load my model in Python and get predictions using tf.contrib.saved_model.load_keras_model(saved_model_path), but I have no idea how to load the model in Java. I searched a lot and found "How to export Keras .h5 to tensorflow .pb?", which exports the model as a .pb file, and then "Loading in Java", which loads that file. I was not able to freeze the graph. I also tried to use simple_save, but tensorflow.keras does not support it (AttributeError: module 'tensorflow.contrib.saved_model' has no attribute 'simple_save'). Can someone help me figure out the steps needed to load my model (a tensorflow.keras functional API model) in Java?

Is the saved_model.pb file I have good enough to be loaded on the Java side? Do I need to create input/output placeholders? If so, how can I export the model?
I appreciate your help.

Hamid K
  • You can use TensorFlow Lite instead: https://www.tensorflow.org/lite – Dec 12 '18 at 15:49

1 Answer


If you have a model saved in the SavedModel format (which appears to be the case, since tf.contrib.saved_model.save_keras_model produces one), then in Java you can use SavedModelBundle.load to load and serve it. You do not need to "freeze" the model.

You may find the following useful:

But the basic idea is that your code will look something like:

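import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

// makeInputTensor() is your own helper that builds the input Tensor.
try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve");
     Tensor<?> input = makeInputTensor();
     Tensor<?> output = model.session().runner()
         .feed("INPUT_TENSOR", input)
         .fetch("OUTPUT_TENSOR")   // fetch(String) takes only the node name
         .run().get(0)) {
  // Use output
}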

Here "INPUT_TENSOR" and "OUTPUT_TENSOR" are the names of the input and output nodes in the TensorFlow graph. The saved_model_cli command-line tool, installed along with TensorFlow for Python, can show you the names of these tensors in your model; for example, saved_model_cli show --dir <directory> --all prints the inputs and outputs of every signature.
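The snippet above leaves makeInputTensor() undefined. As a hedged sketch of the pure-Java half of such a helper, the code below resizes a BufferedImage to 224×224 and packs it into a float[1][224][224][3] array (NHWC, batch of one) with RGB values scaled to [0, 1]. The 224×224 size, the RGB channel order and the [0, 1] scaling are assumptions: they must match the resolution and preprocessing used when the Keras model was fine-tuned (keras.applications.vgg16.preprocess_input, for instance, converts to BGR and subtracts per-channel means instead of scaling). The class and method names (VggInput, toInputArray) are hypothetical. With the TensorFlow 1.x Java API, the returned array can then be wrapped with Tensor.create(pixels).

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

public class VggInput {
    // Assumed training resolution; must match the size the Keras model was fine-tuned on.
    static final int SIZE = 224;

    /** Resizes an image to SIZE x SIZE and returns a [1][SIZE][SIZE][3] RGB array scaled to [0, 1]. */
    static float[][][][] toInputArray(BufferedImage src) {
        // Draw the source into a fixed-size RGB canvas using bilinear interpolation.
        BufferedImage resized = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, SIZE, SIZE, null);
        g.dispose();

        float[][][][] pixels = new float[1][SIZE][SIZE][3]; // NHWC, batch of one
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int rgb = resized.getRGB(x, y);
                pixels[0][y][x][0] = ((rgb >> 16) & 0xFF) / 255f; // red
                pixels[0][y][x][1] = ((rgb >> 8) & 0xFF) / 255f;  // green
                pixels[0][y][x][2] = (rgb & 0xFF) / 255f;         // blue
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(300, 200, BufferedImage.TYPE_INT_RGB);
        float[][][][] in = toInputArray(img);
        // prints 1x224x224x3
        System.out.println(in.length + "x" + in[0].length + "x"
                + in[0][0].length + "x" + in[0][0][0].length);
    }
}
```

Fixing the size on the Java side matters for models that end in a Flatten layer: the Dense weights after it were built for one specific spatial size, so any other input size fails at the MatMul.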

Note that the TensorFlow Java API may be better suited to server and desktop applications than TensorFlow Lite, which another commenter suggested. The TensorFlow Lite runtime is optimized (in memory footprint, etc.) for small devices, but it does not yet support all models, whereas the TensorFlow Java API uses exactly the same runtime as TensorFlow for Python and therefore has exactly the same capabilities.

Hope that helps.

Oliv
ash
  • I did exactly the same, and it finally worked when I used an Inception model as the pre-trained base. But when I use a VGG model as the base, my model cannot be loaded in Java. Have you seen any tutorial that loads a VGG model, fine-tunes it, and then loads it in Java? I will send you the errors I get when I try the VGG model. I really appreciate your help. – Hamid K Dec 21 '18 at 14:30
  • This is the error that I get: Matrix size-incompatible: In[0]: [1,8192], In[1]: [25088,256] [[{{node dense/MatMul}} = MatMul[T=DT_FLOAT, _output_shapes=[[?,256]], transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](flatten/Reshape, dense/MatMul/ReadVariableOp)]] – Hamid K Jan 07 '19 at 19:19
  • The example link in the answer is no longer valid. – marquies Jul 23 '21 at 16:10
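The "Matrix size-incompatible" error reported in the comments can be checked with a little arithmetic. In[1]: [25088,256] is the kernel of the Dense layer, so the Flatten layer was built to emit 25088 features, but In[0]: [1,8192] shows it received only 8192. Assuming a standard VGG16/VGG19 convolutional base (include_top=False: 'same'-padded convolutions, five 2×2 stride-2 max-pooling layers, 512 channels in the last block), 25088 = 7 × 7 × 512 corresponds to a 224×224 input, while 8192 = 4 × 4 × 512 corresponds to inputs of roughly 128 to 159 pixels per side. One plausible reading, not confirmed in the thread, is that the images fed from Java were smaller than the ones used in training. The class below (VggFlattenSize, a hypothetical name) is a minimal sketch of that calculation.

```java
public class VggFlattenSize {
    // Assumption: a VGG16/VGG19 base whose last block outputs 512 channels
    // after five 2x2, stride-2 max-pooling layers ('same' convs keep the size).
    static final int CHANNELS = 512;
    static final int POOLING_LAYERS = 5;

    /** Spatial size after the pooling layers; each 2x2/stride-2 pool floors the size in half. */
    static int pooledSide(int side) {
        for (int i = 0; i < POOLING_LAYERS; i++) {
            side /= 2;
        }
        return side;
    }

    /** Number of features the Flatten layer produces for an input of the given height and width. */
    static int flattenSize(int height, int width) {
        return pooledSide(height) * pooledSide(width) * CHANNELS;
    }

    public static void main(String[] args) {
        System.out.println(flattenSize(224, 224)); // 7 * 7 * 512 = 25088, the size dense/MatMul expects
        System.out.println(flattenSize(128, 128)); // 4 * 4 * 512 = 8192, one input size that matches the error
    }
}
```

If this reading is right, resizing the Java-side images to the training resolution before building the input tensor should make the shapes agree.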