
I am using the PySpark (2.3.0) API to create a custom transformer. I want a simple transformer that takes any function as a param, and I tried doing this by using identity in TypeConverters. The code works; the only problem is that I am unable to save it. It throws an error saying the function object is not JSON serializable. Is there a way to get around that?

I am passing a function object as a param because I want to use it to process the DataFrame in the _transform method.
So the question is: how do I modify this code so that I can save the transformer by setting it as a stage in a PipelineModel object and using that object's PySpark ML writer?

Here is the custom transformer code; I borrowed it from Create a custom Transformer in PySpark ML.

from pyspark import keyword_only
from pyspark.ml import Transformer, PipelineModel
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

# Registry mapping function names to function objects
functions_dict = dict()

def add_to_dict(f):
    functions_dict[f.__name__] = f
    return f

@add_to_dict
def my_foo(df):
    return df.withColumn("dummy", lit(3))

class MyTransformer(
        Transformer,
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    param_dict = Param(Params._dummy(), "param_dict", "function object",
                      typeConverter=TypeConverters.identity)


    @keyword_only
    def __init__(self, param_dict=None):
        super(MyTransformer, self).__init__()
        self.param_dict = Param(self, "param_dict", "")
        self._setDefault(param_dict=param_dict)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, param_dict=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setParamdict(self, value):
        return self._set(param_dict=value)

    def getParamdict(self):
        return self.getOrDefault(self.param_dict)

    def _transform(self, dataset):
        df = self.getParamdict()(dataset)
        return df


spark = (SparkSession.builder
         .appName("spark_run")
         .enableHiveSupport()
         .getOrCreate())

df = spark.sql("select * from table_name")
t1 = MyTransformer(param_dict=functions_dict["my_foo"])
df3 = t1.transform(df)
df3.printSchema()
df3.show()

stages = [t1]
pmodel = PipelineModel(stages=stages)
pmodel.write().overwrite().save("mytransformer")

pmodel1 = PipelineModel.load("mytransformer")
df2 = pmodel1.transform(df)
df2.printSchema()
df2.show()

I am unable to save my PipelineModel. I get the following error.

Traceback (most recent call last):
  File "/Users/code_v1/test.py", line 81, in <module>
    pmodel.write().overwrite().save("mytransformer")
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 135, in save
    self.saveImpl(path)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/pipeline.py", line 226, in saveImpl
    PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/pipeline.py", line 363, in saveImpl
    .getStagePath(stage.uid, index, len(stages), stagesDir))
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 135, in save
    self.saveImpl(path)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 384, in saveImpl
    DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 403, in saveMetadata
    paramMap)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/site-packages/pyspark/ml/util.py", line 427, in _get_metadata_to_save
    return json.dumps(basicMetadata, separators=[',',  ':'])
  File "/Users/anaconda3/envs/spark23/lib/python2.7/json/__init__.py", line 251, in dumps
    sort_keys=sort_keys, **kw).encode(obj)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 207, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 270, in iterencode
    return _iterencode(o, 0)
  File "/Users/anaconda3/envs/spark23/lib/python2.7/json/encoder.py", line 184, in default
    raise TypeError(repr(o) + " is not JSON serializable")
TypeError: <function my_foo at 0x7f7f384049d0> is not JSON serializable

1 Answer


The short answer is that a function object cannot be serialized. One way to get around this is to have the param hold a string that maps to a valid function, rather than the function itself. Then, once you load the pipeline back, you can link that string to the function you've defined.
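A minimal sketch of that idea, building on the functions_dict registry from the question: the param stores only the function's name (a plain string, which is JSON serializable), and _transform looks the function up in the registry at transform time. The class name FunctionNameTransformer and the param name func_name are illustrative, not part of the original code.

from pyspark import keyword_only
from pyspark.ml import Transformer, PipelineModel
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class FunctionNameTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):

    # A string param is JSON serializable, so DefaultParamsWritable can save it.
    func_name = Param(Params._dummy(), "func_name",
                      "name of a function registered in functions_dict",
                      typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, func_name=None):
        super(FunctionNameTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, func_name=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def getFuncName(self):
        return self.getOrDefault(self.func_name)

    def _transform(self, dataset):
        # Look the function up in the registry at transform time,
        # instead of carrying the function object in the param.
        return functions_dict[self.getFuncName()](dataset)

t1 = FunctionNameTransformer(func_name="my_foo")
pmodel = PipelineModel(stages=[t1])
pmodel.write().overwrite().save("mytransformer")
pmodel1 = PipelineModel.load("mytransformer")

This should save and load, since the only thing written to the stage's metadata is the string "my_foo"; the loading process just needs the transformer class, functions_dict, and my_foo to be defined before transform is called.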
