
TensorFlow version: 1.14

Our current setup uses TensorFlow estimators to do live NER, i.e. to perform inference one document at a time. We have 30 different fields to extract and run one model per field, so we have 30 models in total.

Our current setup uses Python multiprocessing to run the inferences in parallel (the inference is done on CPUs). This approach reloads the model weights every time a prediction is made.
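One general way to stop reloading weights on every request is to cache the loaded model inside each process, keyed by its path. The following is a minimal sketch, not the setup described here: `load_predictor` is a hypothetical stand-in for the real (expensive) loading call, and `LOAD_COUNT` exists only to show how often a load actually happens.

```python
import functools

LOAD_COUNT = {"n": 0}  # illustration only: counts real loads in this process


def load_predictor(model_path):
    # HYPOTHETICAL stand-in for the expensive load, e.g.
    # tf.contrib.predictor.from_saved_model(model_path) in TF 1.14.
    LOAD_COUNT["n"] += 1
    return lambda doc: f"{model_path}:{len(doc)}"


@functools.lru_cache(maxsize=None)
def get_predictor(model_path):
    # First call for a given path loads the model; later calls in the
    # same process return the cached object instead of reloading weights.
    return load_predictor(model_path)


def predict(model_path, doc):
    return get_predictor(model_path)(doc)
```

The cache is per process, so each worker pays the load cost once per model rather than once per prediction. With 30 fields, a worker that serves every field ends up holding all 30 models in memory, which is the trade-off of this design.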

Using the approach mentioned here, we exported the estimator models as tf.saved_model. This works as expected in that it does not reload the weights for each request. It also works fine for single-field inference in one process, but it doesn't work with multiprocessing: all the processes hang when we call the predict function (predict_fn in the linked post).

This post is related, but I'm not sure how to adapt it for a saved model.

Importing TensorFlow separately in each predictor did not work either:

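class SavedModelPredictor():

    def __init__(self, model_path):
        # import TensorFlow inside the predictor instead of at module level
        import tensorflow as tf
        # load the SavedModel once, when the predictor is constructed
        self.predictor_fn = tf.contrib.predictor.from_saved_model(model_path)

    def predict(self, input_dict):
        # run the loaded SavedModel on one input dict
        return self.predictor_fn(input_dict)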
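A commonly reported cause of such hangs is that TensorFlow's runtime (its thread pools and locks) does not survive `fork()`: if the class above is constructed in the parent process, TensorFlow is already initialised there before the workers are forked. A frequently suggested workaround is to use the `spawn` start method and load the model inside each worker via a pool initializer. The sketch below assumes that approach; `load_predictor` is a hypothetical stand-in for `tf.contrib.predictor.from_saved_model`, and the other names are illustrative.

```python
import multiprocessing as mp

_PREDICTOR = None  # set once per worker process by init_worker


def load_predictor(model_path):
    # HYPOTHETICAL stand-in. In the real setup, import tensorflow and call
    # tf.contrib.predictor.from_saved_model(model_path) here, so TensorFlow
    # is only initialised inside the child process, never in the parent.
    return lambda input_dict: {"model": model_path, "num_inputs": len(input_dict)}


def init_worker(model_path):
    # Pool initializer: runs once when each worker process starts.
    global _PREDICTOR
    _PREDICTOR = load_predictor(model_path)


def predict(input_dict):
    # Runs in a worker and reuses the predictor loaded by init_worker.
    return _PREDICTOR(input_dict)


def run_parallel(model_path, inputs, processes=2):
    # "spawn" starts fresh interpreters instead of fork()-ing a parent
    # that may already hold TensorFlow threads and locks.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes, initializer=init_worker, initargs=(model_path,)) as pool:
        return pool.map(predict, inputs)
```

`run_parallel` must be called from under an `if __name__ == "__main__":` guard, which the spawn start method requires. For 30 fields, one option is one pool (or one long-lived process) per field, each initialised with that field's export directory.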

How can we make tf.saved_model work with multiprocessing?

arun

2 Answers


Ray Serve, Ray's model-serving library, also supports offline batching. You can wrap your model in a Ray Serve backend and scale it to as many replicas as you want.

import ray
from ray import serve

client = serve.start()

class MyTFModel:
    def __init__(self, model_path):
        self.model = ... # load model

    @serve.accept_batch
    def __call__(self, input_batch):
        assert isinstance(input_batch, list)

        # forward pass over the whole batch
        predictions = self.model([item.data for item in input_batch])

        # return a list with one response per request in the batch
        return list(predictions)

# model_path: path to the exported model, passed to MyTFModel.__init__
client.create_backend("tf", MyTFModel, model_path,
    # configure resources
    ray_actor_options={"num_cpus": 2, "num_gpus": 1},
    # configure replicas and batching
    config={
        "num_replicas": 2,
        "max_batch_size": 24,
        "batch_wait_timeout": 0.5
    }
)
client.create_endpoint("tf", backend="tf")
handle = client.get_handle("tf")

# perform inference on a list of inputs
futures = [handle.remote(data) for data in fields]
result = ray.get(futures)
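With `max_batch_size` set, the backend is called with a list of up to that many requests and must return exactly one response per request, in order. That contract can be sketched without Ray as a plain offline batching loop. This is an illustration, not Ray's implementation: `batched_predict` is a hypothetical name, and the `batch_wait_timeout` behaviour (waiting for a batch to fill) is not modelled.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched_predict(model_fn: Callable[[List[T]], Sequence[R]],
                    items: Sequence[T],
                    max_batch_size: int) -> List[R]:
    """Call model_fn on consecutive chunks of at most max_batch_size items.

    Each call receives a list and must return one result per input,
    in the same order, mirroring a batched serving backend.
    """
    results: List[R] = []
    for start in range(0, len(items), max_batch_size):
        batch = list(items[start:start + max_batch_size])
        out = list(model_fn(batch))
        if len(out) != len(batch):
            raise ValueError(f"model returned {len(out)} results for {len(batch)} inputs")
        results.extend(out)
    return results
```

Batching pays off when one forward pass over N inputs is cheaper than N separate passes, which is typically the case for neural models.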

Try it out with the nightly wheel; here's the tutorial: https://docs.ray.io/en/master/serve/tutorials/batch.html

Edit: updated the code sample for Ray 1.0

Simon Mo

OK, so the approach outlined in this answer with Ray worked.

I built a class like this, which loads the model once on init and exposes a method, run, to perform prediction:

import tensorflow as tf
import ray

ray.init()

@ray.remote
class MyModel(object):

    def __init__(self, field, saved_model_path):
        self.field = field
        # load the model once in the constructor
        self.predictor_fn = tf.contrib.predictor.from_saved_model(saved_model_path)

    def run(self, df_feature, *args):
        # ...
        # code to perform prediction using self.predictor_fn
        # ...
        return self.field, list_pred_string, list_pred_proba

Then I used it in the main module as follows:

from pathlib import Path

# fields: the list of field names; build a dictionary mapping each field to a MyModel actor
model_dict = {}
for field in fields:
    export_dir = f"saved_model/{field}"
    # pick the most recent export, skipping temp-* directories
    subdirs = [x for x in Path(export_dir).iterdir()
               if x.is_dir() and 'temp' not in x.name]
    latest = str(sorted(subdirs)[-1])
    model_dict[field] = MyModel.remote(field, latest)
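The directory-selection step can be factored into a small helper. A sketch, assuming the usual Estimator export layout (one sub-directory per export, named by an integer Unix timestamp, plus possible `temp-<timestamp>` directories from unfinished exports); `latest_export_dir` is a hypothetical name.

```python
from pathlib import Path


def latest_export_dir(export_dir):
    """Return the newest timestamp-named SavedModel export under export_dir.

    Assumes export sub-directories are named by integer Unix timestamps;
    anything else (e.g. "temp-<timestamp>" or stray files) is ignored.
    """
    candidates = [p for p in Path(export_dir).iterdir()
                  if p.is_dir() and p.name.isdigit()]
    if not candidates:
        raise FileNotFoundError(f"no timestamped export found in {export_dir}")
    # compare timestamps numerically rather than as strings
    return str(max(candidates, key=lambda p: int(p.name)))
```

Usage in the loop would then be `model_dict[field] = MyModel.remote(field, latest_export_dir(f"saved_model/{field}"))`. Keeping only all-digit names makes the filter independent of whatever the parent path contains.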

Then I used the model dictionary to make predictions like this:

results = ray.get([model_dict[field].run.remote(df_feature) for field in fields])

Update:

While this approach works, I found that running estimators in parallel with multiprocessing is faster than running predictors in parallel with Ray, especially for large documents. The predictor approach seems to work well for a small number of dimensions and when the input data is not large. An approach like the one mentioned here might be better for our use case.

arun
“Faster to use multiproc”, but in my experience you can’t import tf in multiple processes. – Kermit May 08 '22 at 00:27