TensorFlow version: 1.14
Our current setup uses TensorFlow Estimators to do live NER, i.e., to perform inference one document at a time. We have 30 different fields to extract, and we run one model per field, for a total of 30 models.
We use Python multiprocessing to run the inferences in parallel. (The inference is done on CPUs.) This approach reloads the model weights each time a prediction is made.
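For context, the current flow looks roughly like the sketch below; model_fn, make_input_fn, model_dirs, and document stand in for our real code and are only illustrative:

from multiprocessing import Pool

import tensorflow as tf

def extract_field(args):
    # Rebuilds the Estimator and reloads the checkpoint on every
    # prediction; this reload is the overhead we want to eliminate.
    model_dir, document = args
    estimator = tf.estimator.Estimator(model_fn, model_dir=model_dir)
    return next(estimator.predict(make_input_fn(document)))

if __name__ == '__main__':
    # model_dirs and document come from our pipeline (hypothetical here).
    with Pool(processes=30) as pool:
        results = pool.map(extract_field,
                           [(d, document) for d in model_dirs])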
Using the approach mentioned here, we exported the Estimator models as a tf.saved_model. This works as expected in that it does not reload the weights for each request. It also works fine for single-field inference in one process, but it does not work with multiprocessing: all the processes hang when we call the predict function (predict_fn in the linked post).
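Concretely, the export and the single-process usage look roughly like this, where estimator is the trained Estimator for one field and the serving input receiver is only a placeholder for our real one:

import tensorflow as tf

def serving_input_receiver_fn():
    # Placeholder receiver; the real one mirrors the model's features.
    tokens = tf.placeholder(tf.string, shape=[None, None], name='tokens')
    return tf.estimator.export.ServingInputReceiver({'tokens': tokens},
                                                    {'tokens': tokens})

export_dir = estimator.export_saved_model('exported_models/field_a',
                                          serving_input_receiver_fn)

# export_saved_model returns the path as bytes in TF 1.x, hence the decode.
# from_saved_model loads the weights once; later calls reuse the session.
predict_fn = tf.contrib.predictor.from_saved_model(export_dir.decode())
predictions = predict_fn({'tokens': [['John', 'lives', 'in', 'London']]})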
This post is related, but we are not sure how to adapt it to a saved model.
Importing TensorFlow individually in each of the predictors did not work either:
class SavedModelPredictor():
    def __init__(self, model_path):
        # Import TF inside the worker so each process builds its own
        # predictor (and session) from the exported model.
        import tensorflow as tf
        self._predictor_fn = tf.contrib.predictor.from_saved_model(model_path)

    def predict(self, input_dict):
        return self._predictor_fn(input_dict)
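For completeness, this sketch shows how we invoke the predictors from the worker processes (the paths and features are illustrative); the hang happens inside the predictor call:

from multiprocessing import Pool

def predict_field(args):
    model_path, features = args
    predictor = SavedModelPredictor(model_path)
    # Works in a single process; hangs here under multiprocessing.
    return predictor.predict(features)

if __name__ == '__main__':
    model_paths = ['exported_models/field_%d' % i for i in range(30)]
    features = {'tokens': [['John', 'lives', 'in', 'London']]}
    with Pool(processes=len(model_paths)) as pool:
        results = pool.map(predict_field,
                           [(p, features) for p in model_paths])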
How can we make tf.saved_model work with multiprocessing?