
I need to add the decode_batch_predictions() method to the output of the Keras Handwriting Recognition OCR model. I want to convert the model to TF Lite with the output already decoded, because I couldn't find any way to decode the TF Lite model's output on Android. I already saw a similar post about a similar Keras model, but its solution doesn't work for this model. I don't have much Python experience, so it's hard for me to adapt the answers on that post to this model. I would really appreciate any help, thanks!

I tried using the code from that post, but it didn't work.

1 Answer


In the notebook for the model linked in your question, add the following code right after `prediction_model` is defined:

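```python
import tensorflow as tf
from tensorflow import keras

# `model`, `num_to_char` and `max_length` come from the handwriting_recognition
# notebook (the notebook may call the maximum label length `max_len` instead).
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)  # This line is present in the handwriting_recognition notebook.


def CTCDecoder():
    def decode_batch_predictions(pred):
        # pred: (batch, time_steps, num_classes). Inside a Lambda layer the batch
        # size is None, so build the rank-1 input_length tensor (one length per
        # sample) from tf.shape() rather than from pred.shape.
        input_len = tf.fill([tf.shape(pred)[0]], tf.shape(pred)[1])
        # Use greedy search. For complex tasks, you can use beam search.
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
        # Map indices back to characters; ctc_decode pads each row with -1,
        # so replace those positions with empty strings.
        chars = num_to_char(results)
        chars = tf.where(tf.equal(results, -1), tf.constant(""), chars)
        # Join each row into one string. Keep this as TF ops: .numpy() and
        # Python lists don't work on the symbolic tensors a layer receives.
        return tf.strings.reduce_join(chars, axis=-1)

    return keras.layers.Lambda(decode_batch_predictions, name="decode")


decoded_pred_model = keras.models.Model(
    prediction_model.input, outputs=CTCDecoder()(prediction_model.output)
)
```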
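If you would rather keep the TF Lite model unchanged and decode on the device instead, greedy CTC decoding is simple enough to port by hand. Below is a minimal NumPy sketch of what `keras.backend.ctc_decode(..., greedy=True)` does for each sample, assuming (as Keras does by default) that the blank is the last class: take the most likely class at every time step, skip repeats of the previous step's class, and drop blanks. The function name `greedy_ctc_decode` is hypothetical, and the returned indices would still need the same index-to-character mapping that `num_to_char` performs.

```python
import numpy as np


def greedy_ctc_decode(pred, blank_index=None):
    """Greedy CTC decoding of a (batch, time_steps, num_classes) array.

    Returns one list of class indices per sample. As in
    keras.backend.ctc_decode, the blank defaults to the last class.
    """
    pred = np.asarray(pred)
    if blank_index is None:
        blank_index = pred.shape[-1] - 1
    best_path = pred.argmax(axis=-1)  # most likely class at each time step
    decoded = []
    for path in best_path:
        labels = []
        previous = None
        for label in path:
            # Emit a label only if it is not blank and differs from the
            # previous time step (a blank in between allows a real repeat).
            if label != blank_index and label != previous:
                labels.append(int(label))
            previous = label
        decoded.append(labels)
    return decoded


# 1 sample, 6 time steps, 4 classes (0, 1, 2 = characters, 3 = blank).
probs = np.eye(4)[[0, 0, 3, 0, 1, 3]][np.newaxis]
print(greedy_ctc_decode(probs))  # [[0, 0, 1]]
```

The blank between the two runs of class 0 is what lets the decoder output "0, 0"; without it, consecutive identical predictions collapse into a single character.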

Then convert `decoded_pred_model` to a `.tflite` file and use it in Android.
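A sketch of the conversion step using `tf.lite.TFLiteConverter.from_keras_model`. The CTC greedy decoder and the string ops in the `decode` layer are, as far as I know, not TF Lite builtin ops, so this enables Select TF ops as a fallback; the Android app then also needs the `tensorflow-lite-select-tf-ops` dependency. The helper name `convert_to_tflite` and the output file name are arbitrary choices, and the lookup-table ops used by `num_to_char` may still need extra converter adjustments.

```python
import tensorflow as tf


def convert_to_tflite(keras_model, path="decoded_pred_model.tflite"):
    """Convert a Keras model to a .tflite file, allowing TF ops TF Lite lacks."""
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # use native TF Lite kernels where possible
        tf.lite.OpsSet.SELECT_TF_OPS,  # fall back to TF kernels (e.g. CTC, string ops)
    ]
    tflite_model = converter.convert()
    with open(path, "wb") as f:
        f.write(tflite_model)
    return path


# Usage with the model built in the answer:
# convert_to_tflite(decoded_pred_model)
```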

MSS
  • Thanks @MSS. On this line `results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]` I get the following error: ` ValueError: Exception encountered when calling layer "decode" (type Lambda). Shape must be rank 1 but is rank 0 for '{{node decode/CTCGreedyDecoder}} = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](decode/Log, decode/Cast)' with input shapes: [32,?,81], []. Call arguments received by layer "decode" (type Lambda): • inputs=tf.Tensor(shape=(None, 32, 81), dtype=float32) • mask=None • training=False ` – Mehdi Karbalai Dec 13 '22 at 13:32
  • Is this error coming from Android? – MSS Dec 13 '22 at 14:51