I have a class in which I instantiate a Keras model to perform predictions. The class is organized roughly like this:
import tensorflow as tf

class MyClass():
    def __init__(self):
        self.model = None

    def load(self, path):
        self.model = tf.keras.models.load_model(path)

    def inference(self, data):
        # ...
        pred = self.model.predict(data)
        # ...
        return pred
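For reference, the sequential baseline I compare against below is essentially this, reusing the same placeholders as in the snippets that follow:

myobj = MyClass()
myobj.load(<Path_to_model>)

# sequential baseline: one predict() call per item
results = [myobj.inference(d) for d in mydata]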
I have been trying to run the MyClass.inference method in parallel. I tried it with joblib.Parallel:
from joblib import Parallel, delayed

n_jobs = 8
myobj = MyClass()
myobj.load(<Path_to_model>)

results = Parallel(n_jobs=n_jobs)(delayed(myobj.inference)(d) for d in mydata)
But I get the following error: TypeError: cannot pickle 'weakref' object
Apparently this is a known issue with Keras (https://github.com/tensorflow/tensorflow/issues/34697) that was supposedly fixed in TF 2.6.0. However, after upgrading TensorFlow to 2.6.0 I still get the same error. I even tried tf-nightly, as suggested in that issue, but it did not work either.
I also tried replacing pickle with dill, via import dill as pickle, but that did not fix it.
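Concretely, the dill attempt was just this module-level alias (a minimal sketch; it only rebinds the name pickle inside my own module, so as far as I can tell joblib's internal serializer is unaffected, which may be why it changed nothing):

import dill as pickle  # rebinds the name locally; joblib still uses its own pickler internally

from joblib import Parallel, delayed

results = Parallel(n_jobs=n_jobs)(delayed(myobj.inference)(d) for d in mydata)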
The only thing that actually worked was replacing the loky backend in Parallel with threading. However, in the scenario I tried, using threading ends up taking pretty much the same time as (or slightly longer than) performing the MyClass.inference calls sequentially.
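For completeness, the variant that avoids the pickling error is just switching the joblib backend (a minimal sketch; the threads share myobj, so nothing needs to be pickled, but as noted above it was not faster than the sequential loop in my case):

from joblib import Parallel, delayed

# threading backend: myobj is shared across threads, so no pickling is needed
results = Parallel(n_jobs=n_jobs, backend="threading")(
    delayed(myobj.inference)(d) for d in mydata
)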
My question is: what are my options here? Is there any way to run a preloaded Keras model's predict in parallel, for example with other Python libraries?