I am using the PySpark (2.3.0) API to create a custom transformer. I want a simple transformer that takes any function as a param, and I tried doing this by using identity in TypeConverters. The code works; the only problem is that I am unable to save the transformer, because it throws an error saying the function object is not JSON serializable. Is there a way to get around that?
I am passing a function object as a param because I want to use it to process the DataFrame in the _transform method.
So the question is: how do I modify this code so that I can save the transformer by setting it as a stage in a PipelineModel and using that object's PySpark ML writer?
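From the traceback below, my understanding is that DefaultParamsWritable dumps the param values into the metadata JSON, so whatever I put in the Param has to survive json.dumps, which a plain function object does not. A minimal illustration of what I think the root cause is (this snippet is only for the question, not part of my pipeline code):

import json

def some_func(df):
    return df

json.dumps({"param_dict": some_func})
# raises TypeError: ... is not JSON serializable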
Here is the custom transformer code. I borrowed it from Create a custom Transformer in PySpark ML.
from pyspark import keyword_only
from pyspark.ml import PipelineModel, Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

# registry of functions that can be handed to the transformer
functions_dict = dict()

def add_to_dict(f):
    functions_dict[f.__name__] = f
    return f

@add_to_dict
def my_foo(df):
    return df.withColumn("dummy", lit(3))

class MyTransformer(
        Transformer,
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    param_dict = Param(Params._dummy(), "param_dict", "function object",
                       typeConverter=TypeConverters.identity)

    @keyword_only
    def __init__(self, param_dict=None):
        super(MyTransformer, self).__init__()
        self.param_dict = Param(self, "param_dict", "")
        self._setDefault(param_dict=param_dict)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, param_dict=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setParamdict(self, value):
        return self._set(param_dict=value)

    def getParamdict(self):
        return self.getOrDefault(self.param_dict)

    def _transform(self, dataset):
        df = self.getParamdict()(dataset)
        return df
spark = (SparkSession.builder.appName("spark_run")
         .enableHiveSupport()
         .getOrCreate())
df = spark.sql("select * from table_name")
t1 = MyTransformer(param_dict=functions_dict["my_foo"])
df3 = t1.transform(df)
df3.printSchema()
df3.show()
stages = [t1]
pmodel = PipelineModel(stages=stages)
pmodel.write().overwrite().save("mytransformer")
pmodel1 = PipelineModel.load("mytransformer")
df2 = pmodel1.transform(df)
df2.printSchema()
df2.show()
I am unable to save my PipelineModel. I get the following error.
Traceback (most recent call last):
File "/Users/code_v1/test.py", line 81, in <module>
pmodel.write().overwrite().save("mytransformer")
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 135, in save
self.saveImpl(path)
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/pipeline.py", line 226, in saveImpl
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/pipeline.py", line 363, in saveImpl
.getStagePath(stage.uid, index, len(stages), stagesDir))
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 135, in save
self.saveImpl(path)
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 384, in saveImpl
DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 403, in saveMetadata
paramMap)
File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 427, in _get_metadata_to_save
return json.dumps(basicMetadata, separators=[',', ':'])
File "/Users/anaconda3/envs/spark23/lib/python2.7/json/__init__.py", line 251, in dumps
sort_keys=sort_keys, **kw).encode(obj)
File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 207, in encode
chunks = self.iterencode(o, _one_shot=True)
File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 270, in iterencode
return _iterencode(o, 0)
File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 184, in default
raise TypeError(repr(o) + " is not JSON serializable")
TypeError: <function my_foo at 0x7f7f384049d0> is not JSON serializable
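One workaround I am considering (I am not sure whether it is the idiomatic way) is to keep the Param JSON-friendly: store only the function's name as a string and resolve it through the module-level functions_dict registry inside _transform. A rough sketch of the idea; MyTransformerByName is just an illustrative name, and it assumes the decorated functions are registered (i.e. their module is imported) in whatever process later loads the pipeline:

class MyTransformerByName(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    # store a plain string (JSON serializable) instead of the function object
    func_name = Param(Params._dummy(), "func_name",
                      "name of a function registered in functions_dict",
                      typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, func_name=None):
        super(MyTransformerByName, self).__init__()
        kwargs = self._input_kwargs
        self._set(**kwargs)

    def _transform(self, dataset):
        # resolve the function at transform time via the registry
        f = functions_dict[self.getOrDefault(self.func_name)]
        return f(dataset)

t2 = MyTransformerByName(func_name="my_foo")

Would that survive PipelineModel save/load cleanly, or is there a better-supported way to attach an arbitrary Python function to a transformer?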