6

I am trying the following PySpark code, which adds a number to every row of an RDD and returns a list of RDDs.

from pyspark.context import SparkContext

file = "file:///home/sree/code/scrap/sample.txt"
sc = SparkContext('local', 'TestApp')
data = sc.textFile(file)
splits = [data.map(lambda p: int(p) + i) for i in range(4)]
print splits[0].collect()
print splits[1].collect()
print splits[2].collect()

The content in the input file (sample.txt) is:

1
2
3

I was expecting output like this (adding 0, 1, and 2 respectively to the numbers in the RDD):

[1,2,3]
[2,3,4]
[3,4,5]

whereas the actual output was:

[4, 5, 6]
[4, 5, 6]
[4, 5, 6]

which means that the comprehension used only the final value of i (3), irrespective of range(4).

Why does this happen?

srjit

2 Answers

4

It happens because of Python's late binding and is not (Py)Spark specific. i will be looked up when lambda p: int(p) + i is used, not when it is defined. Typically that means when it is called, but in this particular context it is when it is serialized to be sent to the workers.
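
For illustration, the same thing happens with a plain list of lambdas, no Spark involved:

fns = [lambda p: int(p) + i for i in range(4)]
# each lambda looks up i when it is called, after the loop has left i == 3
[f("1") for f in fns]
## [4, 4, 4, 4]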

You can, for example, do something like this:

def f(i):
    # i is evaluated when f(i) is called, so each _f keeps its own value of i
    def _f(x):
        try:
            return int(x) + i
        except ValueError:
            # silently skip lines that are not valid integers
            pass
    return _f

data = sc.parallelize(["1", "2", "3"])
splits = [data.map(f(i)) for i in range(4)]
[rdd.collect() for rdd in splits]
## [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
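
A similar fix, as a minimal sketch assuming the same data RDD, is to bind i up front with functools.partial instead of a closure factory:

from functools import partial

def add(i, x):
    # i is fixed when partial(add, i) is built, not looked up later
    return int(x) + i

splits = [data.map(partial(add, i)) for i in range(4)]
[rdd.collect() for rdd in splits]
## [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]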
zero323
  • I had tried passing 'p' to a simple external function, and to an inner function (like the one in the answer) called through a lambda, for trial and error purposes. I noticed the correct behavior when I did this: http://pastebin.com/z7E7wGKx Thank you for replying with the reason why this happens. – srjit Jun 29 '16 at 04:12
  • 1
    Worth noting that this happens in just about any language with closures/lambdas, even C#. – Austin_Anderson Oct 08 '17 at 16:45
2

This happens because the lambdas refer to i by reference; it has nothing to do with Spark. See this.

You can try this:

# applying the outer lambda to i immediately binds the current value of i
a = [(lambda y: (lambda x: y + int(x)))(i) for i in range(4)]
splits = [data.map(a[x]) for x in range(4)]

or in one line

splits = [
    data.map([(lambda y: (lambda x: y + int(x)))(i) for i in range(4)][x])
    for x in range(4)
]
Himaprasoon
  • 1
    If you want to use lambdas there is a simple trick which avoids nesting: `[lambda x, i=i: i + int(x) for i in range(4)]`. – zero323 Jun 29 '16 at 20:59
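
Applied to the question's example, that default-argument trick looks like this (a sketch assuming the same data RDD):

# i=i is evaluated at definition time and stored as a default argument,
# so each lambda carries its own copy of the offset
splits = [data.map(lambda p, i=i: int(p) + i) for i in range(4)]
[rdd.collect() for rdd in splits]
## [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]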