I am using the TensorFlow Java API (called from Scala) to serve a TensorFlow model trained in Python. That model has two inputs. The code is the following:
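```scala
import java.nio.{ByteBuffer, FloatBuffer}
import org.tensorflow.Tensor
import org.tensorflow.types.UInt8

def predict(float32InputShape: (Long, Long),
            float32Inputs: Seq[Seq[Float]],
            uint8InputShape: (Long, Long),
            uint8Inputs: Seq[Seq[Byte]]
           ): Array[Float] = {
  // float32 tensor of shape [rows, cols], filled row by row
  val float32Input = Tensor.create(
    Array(float32InputShape._1, float32InputShape._2),
    FloatBuffer.wrap(float32Inputs.flatten.toArray)
  )
  // uint8 tensor: the JVM has no unsigned byte, so the raw bytes are passed with the UInt8 type tag
  val uint8Input = Tensor.create(
    classOf[UInt8],
    Array(uint8InputShape._1, uint8InputShape._2),
    ByteBuffer.wrap(uint8Inputs.flatten.toArray)
  )
  val tfResult = session
    .runner()
    .feed("serving_default_float32_Input", float32Input)
    .feed("serving_default_uint8_Input", uint8Input)
    .fetch("PartitionedCall")
    .run()
    .get(0)
    .expect(classOf[java.lang.Float])
  // expect() returns a Tensor, not an Array, so copy its contents into a float array
  val out = FloatBuffer.allocate(tfResult.numElements())
  tfResult.writeTo(out)
  // Tensors hold native memory and should be released explicitly
  float32Input.close()
  uint8Input.close()
  tfResult.close()
  out.array()
}
```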
What I would like to do is refactor that method to make it more generic, passing the inputs as a map, like `feed_dict` in Python. That is, something like:
```scala
def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
  ...
  session
    .runner()
    .feed(inputs)
  ...
}
```
where the keys of the `inputs` map would be the names of the input layers. This isn't possible with the `feed` method as it is, since `Session.Runner.feed` takes a single operation name and a single tensor per call, unless I write a macro (which I want to avoid).

Is there any way to do this with the Java API of TensorFlow (I'm using TF 2.0)?
Edit: I found the solution (thanks to @geometrikal's answer). The code is in Scala, but it shouldn't be too hard to do the same in Java.
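```scala
// createTensor (not shown) converts one input array into a Tensor;
// outputLayer is the name of the operation to fetch.
val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
  case (runner, (layerName, array)) =>
    val tensor = createTensor(array)
    runner.feed(layerName, tensor)
}
val output = runnerWithInputLayers
  .fetch(outputLayer)
  .run()
  .get(0)
  .expect(classOf[java.lang.Float]) // Float.getClass would be the class of Scala's Float companion object
```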
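This works because `feed` returns the `Session.Runner` itself (with the new input registered), so the calls can be chained, and `foldLeft` can thread the runner through one `feed` call per entry of the map.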
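To check the mechanics of this fold without a model or the TensorFlow native library, here is a TensorFlow-free sketch. `RecordingRunner` is a hypothetical stand-in for `Session.Runner` that only records what was fed, and `toShapeAndBuffer` is a hypothetical helper that builds the `(shape, FloatBuffer)` pair accepted by `Tensor.create(long[], FloatBuffer)`, so a `createTensor` for `Seq[Seq[Float]]` could be built on top of it. Unlike a bare `flatten`, it rejects ragged input.

```scala
import java.nio.FloatBuffer
import scala.collection.mutable

object FeedDictSketch {
  // Hypothetical stand-in for org.tensorflow.Session.Runner: like the real
  // Runner, feed() registers one input and returns the runner itself.
  final class RecordingRunner {
    private val fed = mutable.LinkedHashMap.empty[String, (Array[Long], FloatBuffer)]

    def feed(name: String, shape: Array[Long], data: FloatBuffer): RecordingRunner = {
      fed(name) = (shape, data)
      this
    }

    def fedNames: Seq[String] = fed.keys.toSeq
    def shapeOf(name: String): Seq[Long] = fed(name)._1.toSeq
    def elementCount(name: String): Int = fed(name)._2.remaining()
  }

  // Hypothetical helper: flattens a rectangular Seq[Seq[Float]] into a
  // [rows, cols] shape plus a row-major buffer. Ragged rows are rejected.
  def toShapeAndBuffer(rows: Seq[Seq[Float]]): (Array[Long], FloatBuffer) = {
    val cols = rows.headOption.map(_.size).getOrElse(0)
    require(rows.forall(_.size == cols), "all rows must have the same length")
    (Array(rows.size.toLong, cols.toLong), FloatBuffer.wrap(rows.flatten.toArray))
  }

  // feed_dict-style fold: one feed() call per map entry, threaded through foldLeft.
  def feedAll(runner: RecordingRunner, inputs: Map[String, Seq[Seq[Float]]]): RecordingRunner =
    inputs.foldLeft(runner) { case (r, (layerName, rows)) =>
      val (shape, buffer) = toShapeAndBuffer(rows)
      r.feed(layerName, shape, buffer)
    }
}
```

With the real API, the same fold would start from `session.runner()` and call `runner.feed(layerName, Tensor.create(shape, buffer))` for each entry.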