I am using Java to serve a TensorFlow model trained in Python. The model has two inputs. The code is the following:

  def predict(float32InputShape: (Long, Long),
              float32Inputs: Seq[Seq[Float]],
              uint8InputShape: (Long, Long),
              uint8Inputs: Seq[Seq[Byte]]
             ): Array[Float] = {

    val float32Input = Tensor.create(
      Array(float32InputShape._1, float32InputShape._2),
      FloatBuffer.wrap(float32Inputs.flatten.toArray)
    )
    val uint8Input = Tensor.create(
      classOf[UInt8],
      Array(uint8InputShape._1, uint8InputShape._2),
      ByteBuffer.wrap(uint8Inputs.flatten.toArray)
    )

    val tfResult = session
      .runner()
      .feed("serving_default_float32_Input", float32Input)
      .feed("serving_default_uint8_Input", uint8Input)
      .fetch("PartitionedCall")
      .run()
      .get(0)
      .expect(classOf[java.lang.Float])

    tfResult
  }

I would like to refactor that method to make it more generic by passing the inputs the way feed_dict does in Python. That is, something like:

    def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
      ...
      session
        .runner()
        .feed(inputs)
      ...
    }

The keys of the inputs map would be the names of the input layers. That doesn't seem possible with the feed method unless I write a macro (which I want to avoid).

Is there any way to do this with the TensorFlow Java API (I'm using TF 2.0)?

Edit: I found the solution (thanks to @geometrikal's answer). The code is in Scala, but it shouldn't be too hard to do the same in Java.

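    // Fold over the inputs, feeding one tensor per input layer into the runner.
    val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
      case (runner, (layerName, array)) =>
        val tensor = createTensor(array)
        runner.feed(layerName, tensor)
    }

    val output = runnerWithInputLayers
      .fetch(outputLayer)
      .run()
      .get(0)
      // classOf[java.lang.Float] rather than Float.getClass, which is the class of Scala's Float companion object
      .expect(classOf[java.lang.Float])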

This works because the .feed method returns the Session.Runner with the input layer registered, so the result of each call can be passed on to the next.
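For the Java side, the same idea can be sketched as a plain loop over the map entries. This is only a minimal sketch: `FakeRunner` is a hypothetical stand-in for `Session.Runner` (so the example runs without a TensorFlow dependency), and `feedAll` and the use of `float[][]` values are assumptions for illustration. With the real API you would build a tensor for each entry and pass it to `feed`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class FeedLoopSketch {

    // Hypothetical stand-in for org.tensorflow.Session.Runner. Like the real
    // class, feed(...) and fetch(...) return the runner so calls can be chained.
    static final class FakeRunner {
        final List<String> fed = new ArrayList<>();
        final List<String> fetched = new ArrayList<>();

        FakeRunner feed(String operation, float[][] value) {
            fed.add(operation); // the real runner would register a Tensor here
            return this;
        }

        FakeRunner fetch(String operation) {
            fetched.add(operation);
            return this;
        }
    }

    // Feeds every map entry into the runner, one feed(...) call per entry,
    // and returns the runner so the caller can continue the chain.
    static FakeRunner feedAll(FakeRunner runner, Map<String, float[][]> inputs) {
        for (Map.Entry<String, float[][]> entry : inputs.entrySet()) {
            runner = runner.feed(entry.getKey(), entry.getValue());
        }
        return runner;
    }

    public static void main(String[] args) {
        // LinkedHashMap keeps insertion order, which makes the printed list predictable.
        Map<String, float[][]> inputs = new LinkedHashMap<>();
        inputs.put("serving_default_float32_Input", new float[][] {{1f, 2f}, {3f, 4f}});
        inputs.put("serving_default_uint8_Input", new float[][] {{5f, 6f, 7f}});

        FakeRunner runner = feedAll(new FakeRunner(), inputs).fetch("PartitionedCall");
        System.out.println(runner.fed);
        System.out.println(runner.fetched);
    }
}
```

The loop reassigns `runner` on each iteration, which is the imperative equivalent of the foldLeft above.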

B.B

1 Answer

You can feed each input in a loop. I'm not so familiar with JavaScript, but the pseudo-code is something like:

val tfResult = session.runner()
for(key, value : inputs) {
    tfResult = tfResult(key, value)
}
tfResult = tfResult.fetch("PartitionedCall")
  .run()
  .get(0)
  .expect(classOf[java.lang.Float])

Remember you can break up the function chain at any point, e.g. `result = foo.bar().baz().qux()` can be written `temp = foo.bar().baz(); result = temp.qux()`.

geometrikal
  • Actually my question was for Java (not JavaScript); the code I provided may be confusing because it's Scala (which is interoperable with Java). I tried `tfResult = tfResult(key, value)` but it doesn't work. – B.B Oct 15 '19 at 11:31
  • @B.B what I provided was pseudo-code, i.e. the logic but not the implementation. You will have to look up how to iterate through a map with Scala and implement that. – geometrikal Oct 15 '19 at 20:25
  • See here https://stackoverflow.com/questions/1066589/iterate-through-a-hashmap – geometrikal Oct 15 '19 at 20:31