I am trying to parallelize predictions from a simple TensorFlow (LSTM) model in Python.
A minimal example that reproduces the problem I'm facing:
import multiprocessing as mp
from tensorflow.keras.models import load_model
import numpy as np

# Define the function to parallelize
def fun(model, input_model, particle):
    print(f'\nEvaluating particle {particle}\n')
    output = model.predict(input_model)
    return output

# Define the main
def main():
    # Load the model
    model = load_model('model_lstm.h5')
    model.summary()
    # Show it works
    dummy_prediction = model.predict(np.random.rand(1, 100, 14))
    # Prepare the input for multiprocessing
    tuple_aux = []
    for i in range(25):
        tuple_aux.append((model, np.random.rand(1, 100, 14), i))
    # Implement the multiprocessing
    with mp.get_context("spawn").Pool() as pool:
        results = pool.starmap(fun, tuple_aux)
    return results

if __name__ == '__main__':
    # Set the start method
    mp.set_start_method("spawn", force=True)
    # Call the main
    results = main()
The output before the multiprocessing call is reached is:
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 100, 14)]         0
 lstm (LSTM)                 (None, 16)                1984
 dense (Dense)               (None, 200)               3400
=================================================================
Total params: 5,384
Trainable params: 5,384
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 378ms/step
Once the multiprocessing call is reached, the following error is raised and the code then hangs:
ValueError: Unknown metric function: 'function'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope
I am not using any custom metric, which the working dummy prediction confirms, so I suspect that part of the model's state is lost when it is passed to the worker processes. I also tried loading the model directly inside the parallelized function instead of passing it as an argument, but that does not work either.
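
To make that second variant concrete, here is a minimal sketch of loading the model inside each worker rather than pickling it across processes. The initializer-based layout, the name `init_worker`, the module-level `_model` cache, and the optional `loader` hook are all illustrative assumptions, not the exact code that was run, and the sketch is not claimed to resolve the error above.

```python
import multiprocessing as mp
import numpy as np

# Per-process cache (hypothetical name): each worker holds its own model copy
_model = None

def init_worker(model_path, loader=None):
    """Run once in every worker process and load the model there."""
    global _model
    if loader is None:
        # Import inside the worker so TensorFlow is initialised per process
        from tensorflow.keras.models import load_model
        loader = load_model
    _model = loader(model_path)

def fun(input_model, particle):
    """Predict with the model owned by the current worker."""
    print(f'\nEvaluating particle {particle}\n')
    return _model.predict(input_model)

def main():
    # Only arrays and integers are sent to the workers, never the model
    args = [(np.random.rand(1, 100, 14), i) for i in range(25)]
    ctx = mp.get_context("spawn")
    with ctx.Pool(initializer=init_worker,
                  initargs=('model_lstm.h5',)) as pool:
        return pool.starmap(fun, args)
```

With the "spawn" start method, `main()` still has to be called from under an `if __name__ == '__main__':` guard, as in the original script.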