
While integrating PySpark into my application's code base, I couldn't reference a class's method inside an RDD's map() method. I reproduced the issue with a simple example, which is as follows.

Here's a dummy class I have defined. It just adds a number to every element of an RDD that is derived from an RDD stored as a class attribute:

from pyspark import SparkContext


class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def add(self, a, b):
        return a + b

    def test_func(self, b):
        # double every value, then add b to each result via self.add()
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: self.add(l[1], b))
        v_c = v.collect()
        return v_c

test_func() calls the map() method on an RDD to build v, and the mapped function in turn calls the add() method on every element. Calling test_func() throws the following error:

pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
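
For completeness, the driver-side call that triggers this looks like the following (the variable name is just for illustration):

t = Test()
t.test_func(5)  # the PicklingError is raised when collect() runs the job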

Now, when I move the add() method out of the class, like so:

def add(a, b):
    return a + b

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):

        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        v_c = v.collect()

        return v_c

Calling test_func() now works properly and returns:

[7, 9, 11]

Why does this happen, and how can I pass class methods to an RDD's map() method?


1 Answer


This happens because when PySpark tries to serialize your function (to send it to the workers), it also needs to serialize the instance of your Test class (the function you pass to map() references this instance through self), and that instance holds a reference to the SparkContext. You need to make sure that the SparkContext and RDDs are not referenced by any object that is serialized and sent to the workers; the SparkContext must live only in the driver.

This should work:

In file testspark.py:

class Test(object):
    def add(self, a, b):
        return a + b

    def test_func(self, a_r, b):
        c_r = a_r.map(lambda l: (l[0], l[1] * 2))
        # now `self` has no reference to the SparkContext()
        v = c_r.map(lambda l: self.add(l[1], b)) 
        v_c = v.collect()
        return v_c

In your main script:

from pyspark import SparkContext
from testspark import Test

sc = SparkContext()
a = [('a', 1), ('b', 2), ('c', 3)]
a_r = sc.parallelize(a)

test = Test()
test.test_func(a_r, 5) # should give [7, 9, 11]
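
A sketch of a variation on the same idea, if you prefer to keep add() on the class: make it a @staticmethod and bind it to a local name before the lambda, so the closure captures a plain function and b rather than self:

from pyspark import SparkContext


class Test(object):
    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    @staticmethod
    def add(a, b):
        return a + b

    def test_func(self, b):
        add = Test.add  # plain function: the lambdas below close over `add` and `b`, not `self`
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        return v.collect()


Test().test_func(5)  # should also give [7, 9, 11]

Since only instances of Test hold a SparkContext (the class object itself does not), nothing in these closures drags the context onto the workers.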
  • Does this actually work in passing methods to workers? If your class is defined in the python kernel, yes, but if you are trying to access a module ( a .py) you'll have to use `.addPyfiles` as `import`ing to kernel won't work ([so ref](https://stackoverflow.com/questions/43532083/pyspark-import-user-defined-module-or-py-files)). – ohailolcat Mar 28 '19 at 19:46
  • @ohailolcat That's a good point. In this answer I'm assuming that the python environment in the workers is the same as that in the driver (i.e: `testspark.py` is deployed to workers and it's in the `PYTHONPATH`) Otherwise `testspark.py` will need to be included through `.addPyfiles` – tomas Apr 15 '19 at 16:31
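
A minimal sketch of the .addPyFile route mentioned in the comments above, assuming testspark.py sits next to the driver script (the path is a placeholder; the PySpark method is SparkContext.addPyFile):

from pyspark import SparkContext
from testspark import Test  # importable locally on the driver

sc = SparkContext()
# ship testspark.py so the workers can resolve references to Test when unpickling closures
sc.addPyFile("testspark.py")

a_r = sc.parallelize([('a', 1), ('b', 2), ('c', 3)])
Test().test_func(a_r, 5)  # should give [7, 9, 11]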