I created the following model:
```python
import mlflow

class EquipmentEmbeddingEndpoint(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # Runs when the model is loaded; pulls the identifier table via Spark
        self.identifiers_df = get_identifier_information()

    def predict(self, context, model_input):
        print('got in predict function')
        return {"output": "dummy"}
```
It calls the following `get_identifier_information()` helper, which is defined in the same Python notebook:
```python
def get_identifier_information():
    # `spark` is the SparkSession global provided by the notebook
    identifiers_df = spark.sql(f"""
        SELECT * FROM third_party_products tpp
    """)
    return identifiers_df
```
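A likely explanation for the SPARK-5063 error in this post (not confirmed by the post itself): MLflow saves a `python_model` with cloudpickle, and cloudpickle serializes classes and functions defined in a notebook (the `__main__` module) by value, together with the globals they reference. `load_context` references `get_identifier_information`, which in turn references the notebook's global `spark`; pickling a SparkSession reaches its SparkContext, which raises exactly this error. The stdlib sketch below lists the globals a function looks up, which is one way to spot such hidden dependencies. `referenced_globals` is a hypothetical helper written for illustration, and the copy of the helper is only inspected, never executed, so `spark` does not need to exist.

```python
import dis


def referenced_globals(func):
    """Return the sorted global names that `func` loads when it runs."""
    return sorted(
        {ins.argval for ins in dis.get_instructions(func) if ins.opname == "LOAD_GLOBAL"}
    )


# Same shape as the helper in the question. `spark` is undefined here,
# which is fine because the bytecode is inspected, not executed.
def get_identifier_information():
    identifiers_df = spark.sql("""
        SELECT * FROM third_party_products tpp
    """)
    return identifiers_df


print(referenced_globals(get_identifier_information))  # ['spark']
```

Running the same check on `EquipmentEmbeddingEndpoint.load_context` would report `get_identifier_information`, which shows the chain from the model class down to the global SparkSession.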
This is how I log the model:
```python
import numpy as np

with mlflow.start_run():
    sample_inputs = np.array(["btr197", "ao smith"])
    mlflow.pyfunc.log_model("test_equipment_embedding",
                            python_model=EquipmentEmbeddingEndpoint(),
                            registered_model_name='test_equipment_embedding',
                            input_example=sample_inputs,
                            )
```
And this is the error I am running into:
```
RuntimeError: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
```
Could I please get some help figuring this out? Thanks in advance!
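One common way around this, sketched here under assumptions rather than as a confirmed fix, is to keep Spark out of the pickled model entirely: run the query on the driver at logging time, write the result to a file, pass that file to `mlflow.pyfunc.log_model` through its `artifacts` argument, and read it back in `load_context` from `context.artifacts`. This also matters because a model serving container typically has no Spark session to query. The artifact name `"identifiers"`, the path, and the CSV columns below are hypothetical; CSV is used only so the sketch runs with the standard library, and the `PythonModel = object` fallback exists solely so it runs where MLflow is not installed.

```python
import csv
import os
import tempfile
from types import SimpleNamespace

try:
    from mlflow.pyfunc import PythonModel
except ImportError:  # fallback so the sketch runs without MLflow installed
    PythonModel = object


class EquipmentEmbeddingEndpoint(PythonModel):
    """Loads identifier rows from a file artifact instead of calling Spark."""

    def load_context(self, context):
        # context.artifacts maps artifact names to local file paths.
        path = context.artifacts["identifiers"]  # hypothetical artifact name
        with open(path, newline="") as f:
            self.identifiers = list(csv.DictReader(f))

    def predict(self, context, model_input):
        return {"output": "dummy"}


# Logging on the driver (needs Spark and MLflow, so not executed here):
#   local_path = "/tmp/identifiers.csv"  # hypothetical path
#   spark.sql("SELECT * FROM third_party_products").toPandas().to_csv(local_path, index=False)
#   with mlflow.start_run():
#       mlflow.pyfunc.log_model("test_equipment_embedding",
#                               python_model=EquipmentEmbeddingEndpoint(),
#                               artifacts={"identifiers": local_path},
#                               registered_model_name="test_equipment_embedding",
#                               input_example=sample_inputs)

# Local check using a stand-in for MLflow's context object.
with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "identifiers.csv")
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["product_id", "brand"])  # hypothetical columns
        writer.writerow(["btr197", "ao smith"])
    model = EquipmentEmbeddingEndpoint()
    model.load_context(SimpleNamespace(artifacts={"identifiers": csv_path}))

print(model.identifiers)  # [{'product_id': 'btr197', 'brand': 'ao smith'}]
```

Because nothing the class references touches `spark`, cloudpickle has no SparkSession to serialize. If pandas is available where the model is served, a Parquet file read with `pandas.read_parquet` would be a natural alternative to CSV.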