
I'm having problems running my PySpark UDFs in a distributed way, e.g. via Databricks Connect.

For example:

import pyspark.sql.functions as f

class MyClass(object):
    def __init__(self, number_string):
        self.number = int(number_string)

    def simple_fun(self, num1, num2):
        return self.number + num1 + num2

    def nested_udf(self):
        @f.udf('string')
        def nested_fun(num1, num2):
            # cast to str so the value matches the declared 'string' return type
            return str(self.simple_fun(num1, num2))

        return nested_fun
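
As I understand it, nested_fun closes over self, so the whole MyClass instance has to be pickled and shipped along with the UDF. Just to illustrate what I mean, a variant that captures only the plain int (a sketch, not the code I'm actually running) would be:

def nested_udf(self):
    number = self.number  # capture only the int, not the instance

    @f.udf('string')
    def nested_fun(num1, num2):
        # the closure now holds a plain int, no reference to MyClass
        return str(number + num1 + num2)

    return nested_fun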

I then run the class with the following code:

from pyspark.sql import SparkSession, Row
from pyspark.sql import types as t
from demo.demoModule import MyClass

spark = SparkSession.builder.getOrCreate()

rdd = spark.sparkContext.parallelize(
    [
        Row(num1=1, num2=1),
        Row(num1=1, num2=2),
        Row(num1=2, num2=2),
        Row(num1=2, num2=3),
    ]
)
schema = t.StructType(
    [
        t.StructField("num1", t.IntegerType(), True),
        t.StructField("num2", t.IntegerType(), True)
    ]
)

input_df = spark.createDataFrame(rdd, schema)

my_class = MyClass("10")
result = input_df.withColumn("new_col", my_class.nested_udf()("num1", "num2"))

result.columns
result.show()
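
For reference, with number = 10 the expected values in new_col are 12, 13, 14 and 15:

+----+----+-------+
|num1|num2|new_col|
+----+----+-------+
|   1|   1|     12|
|   1|   2|     13|
|   2|   2|     14|
|   2|   3|     15|
+----+----+-------+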

I get no issues if I run this locally using plain pyspark, but if I try to run it with databricks-connect I get the following error:

pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 165, in _read_with_length
    return self.loads(obj)
  File "/databricks/spark/python/pyspark/serializers.py", line 469, in loads
    return pickle.loads(obj, encoding=encoding)
ModuleNotFoundError: No module named 'demo'
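
From the traceback, the executor that unpickles the UDF cannot import the demo package, so it looks like the package simply is not available on the cluster. One workaround I considered (untested sketch; 'demo.zip' is a placeholder for a zip of the demo package) is shipping it explicitly before creating the UDF:

spark.sparkContext.addPyFile("demo.zip")  # makes the demo package importable on the executors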

I also tried to work around it by tweaking my code:

def take_df(self, df, colname1, colname2):
    # simple_fun returns an int, so declare the matching return type
    my_udf = f.udf(self.simple_fun, "integer")
    return df.withColumn("new_col", my_udf(colname1, colname2))

...

my_class = MyClass("10")
result = my_class.take_df(input_df, "num1", "num2")

But I get the same error. Any hints on how to get around it are welcome!
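
Update: my current understanding is that f.udf(self.simple_fun, ...) still pickles the bound method, which drags the MyClass instance (and therefore the demo module) along with it, hence the identical error. As a sanity check I also sketched a completely module-free version defined directly in the driver script (names are mine, untested on the cluster):

number = my_class.number  # plain int, no reference to MyClass

@f.udf('integer')
def inline_fun(num1, num2):
    # defined in the driver script; the closure holds only a plain int
    return number + num1 + num2

result = input_df.withColumn("new_col", inline_fun("num1", "num2"))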

