We are currently testing a prediction engine based on Spark's implementation of LDA in Python (we are using the pyspark.ml package, not pyspark.mllib):
https://spark.apache.org/docs/2.2.0/ml-clustering.html#latent-dirichlet-allocation-lda
https://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#pyspark.ml.clustering.LDA
We were able to successfully train a model on a Spark cluster (using Google Cloud Dataproc). Now we are trying to use the model to serve real-time predictions through an API (e.g. a Flask application).
What would be the best approach to achieve this?
Our main pain point is that we seem to need the whole Spark environment just to load the trained model and run the transform. So far we've tried running Spark in local mode for each received request, but this approach gave us:
- Poor performance (time to spin up the SparkSession, load the models, run the transform...)
- Poor scalability (inability to process concurrent requests)
The whole approach seems quite heavy. Would there be a simpler alternative, or even one that does not involve Spark at all?
Below is simplified code for the training and prediction steps.
Training code
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizer
from pyspark.ml.clustering import LDA

def train(input_dataset):
    conf = pyspark.SparkConf().setAppName("lda-train")
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Generate count vectors
    count_vectorizer = CountVectorizer(...)
    vectorizer_model = count_vectorizer.fit(input_dataset)
    vectorized_dataset = vectorizer_model.transform(input_dataset)

    # Instantiate LDA model (the "em" optimizer yields a DistributedLDAModel)
    lda = LDA(k=100, maxIter=100, optimizer="em", ...)

    # Train LDA model
    lda_model = lda.fit(vectorized_dataset)

    # Save models to external storage
    vectorizer_model.write().overwrite().save("gs://...")
    lda_model.write().overwrite().save("gs://...")
Prediction code
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizerModel
from pyspark.ml.clustering import DistributedLDAModel

def predict(input_query):
    conf = pyspark.SparkConf().setAppName("lda-predict").setMaster("local")
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Load models from external storage
    vectorizer_model = CountVectorizerModel.load("gs://...")
    lda_model = DistributedLDAModel.load("gs://...")

    # Run prediction on the input data using the loaded models
    vectorized_query = vectorizer_model.transform(input_query)
    transformed_query = lda_model.transform(vectorized_query)
    ...
    # transform() is lazy, so collect the rows before stopping the session
    result = transformed_query.collect()
    spark.stop()
    return result
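
For reference, the per-request cost described above (starting the SparkSession and loading both models on every call) could in principle be paid only once per process. The following is a minimal sketch of that idea, not a tested solution: `LazyResource` and `load_spark_models` are hypothetical names introduced here, the two path arguments stand in for the elided gs:// locations, and it is an assumption that keeping one local SparkSession alive inside the API process is acceptable for this workload.

```python
import threading


class LazyResource:
    """Build an expensive resource on first use, then reuse it (hypothetical helper)."""

    def __init__(self, factory):
        self._factory = factory
        self._lock = threading.Lock()
        self._value = None
        self._ready = False

    def get(self):
        # Double-checked locking: only the first caller pays the build cost,
        # later callers (including concurrent ones) reuse the same object.
        if not self._ready:
            with self._lock:
                if not self._ready:
                    self._value = self._factory()
                    self._ready = True
        return self._value


def load_spark_models(vectorizer_path, lda_path):
    """Hypothetical loader: start Spark and load both models a single time."""
    # Imported here so this module can be imported without pyspark installed.
    from pyspark.sql import SparkSession
    from pyspark.ml.feature import CountVectorizerModel
    from pyspark.ml.clustering import DistributedLDAModel

    spark = SparkSession.builder.appName("lda-predict").master("local").getOrCreate()
    vectorizer_model = CountVectorizerModel.load(vectorizer_path)
    lda_model = DistributedLDAModel.load(lda_path)
    return spark, vectorizer_model, lda_model
```

A request handler would then call something like `models.get()` on a module-level `LazyResource(lambda: load_spark_models(vectorizer_path, lda_path))` instead of rebuilding everything inside `predict`. This only removes the repeated start-up and load time; it does not by itself answer whether a single local SparkSession can handle concurrent requests well, nor whether Spark can be avoided entirely.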