
Let's start with a simple function that returns a random integer:

import numpy as np

def f(x):
    return np.random.randint(1000)

and an RDD filled with zeros and mapped using f:

rdd = sc.parallelize([0] * 10).map(f)

Since the above RDD is not persisted, I expect to get different output every time I collect:

> rdd.collect()
[255, 512, 512, 512, 255, 512, 255, 512, 512, 255]

If we ignore the fact that the distribution of values doesn't really look random, that is more or less what happens. The problem starts when we take only the first element:

assert len(set(rdd.first() for _ in xrange(100))) == 1

or

assert len(set(tuple(rdd.take(1)) for _ in xrange(100))) == 1

It seems to return the same number each time. I've been able to reproduce this behavior on two different machines with Spark 1.2, 1.3 and 1.4. Here I am using np.random.randint but it behaves the same way with random.randint.

This issue, like the not-exactly-random results from collect, seems to be Python-specific; I couldn't reproduce it using Scala:

def f(x: Int) = scala.util.Random.nextInt(1000)

val rdd = sc.parallelize(List.fill(10)(0)).map(f)
(1 to 100).map(x => rdd.first).toSet.size

rdd.collect()

Did I miss something obvious here?

Edit:

It turns out the source of the problem is the Python RNG implementation. To quote the official documentation:

The functions supplied by this module are actually bound methods of a hidden instance of the random.Random class. You can instantiate your own instances of Random to get generators that don’t share state.

I assume NumPy works the same way, and rewriting f using a RandomState instance as follows

import os
import binascii

def f(x, seed=None):
    # derive a fresh 32-bit seed from the OS entropy pool unless one is given
    seed = (
        seed if seed is not None
        else int(binascii.hexlify(os.urandom(4)), 16))
    # a private RandomState, so no state is shared between processes
    rs = np.random.RandomState(seed)
    return rs.randint(1000)

makes it slower but solves the problem.
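
Since constructing a fresh RandomState per record is relatively expensive, it can instead be initialized once per partition. A minimal sketch of that variant (the f_partition helper is illustrative, not part of the original code):

import os
import binascii
import numpy as np

def f_partition(iterator):
    # one independently seeded RandomState per partition
    seed = int(binascii.hexlify(os.urandom(4)), 16)
    rs = np.random.RandomState(seed)
    for x in iterator:
        yield rs.randint(1000)

rdd = sc.parallelize([0] * 10).mapPartitions(f_partition)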

While the above explains the non-random results from collect, I still don't understand how it affects first / take(1) across multiple actions.

zero323
  • Just to clarify: if I'm using NumPy's random function in Spark, does it always choose the same results in each partition? How can I use np.random.choice so it would be random? – member555 Aug 29 '15 at 15:34
  • _Does it always choose the same results in each partition?_ - not exactly, but values computed on a single worker won't be independent. _How can I use np.random.choice so it would be random?_ - I've already described a solution in the edit. You should use a separate state. Since it is rather expensive, you may want to initialize it once per partition. – zero323 Aug 31 '15 at 00:55
  • Can you explain in more detail what the problem is? Why is Python's shared state a problem? – member555 Aug 31 '15 at 20:48
  • @member555 Well, it is a broad question. Long story short, RNGs like these are pseudorandom and generate a deterministic sequence of values; the same state can be read multiple times by different workers before it is updated. A simple [SO search](http://stackoverflow.com/search?q=[python]+random+multiprocessing) should provide you with more details. – zero323 Oct 05 '15 at 13:23
  • This solved my problem, but shouldn't the Edit part be part of the answer? – Akavall Jul 05 '16 at 21:06
  • @Akavall It probably should, but there's 9 months between these two. I figured out half of the issue pretty quickly and hoped that someone else would fill in the blanks. I'll try to reorganize this when I have a spare moment. And I am glad it helped. – zero323 Jul 05 '16 at 22:23

3 Answers


So the actual problem here is relatively simple. Each subprocess in Python inherits its RNG state from its parent:

import random

len(set(sc.parallelize(range(4), 4).map(lambda _: random.getstate()).collect()))
# 1

Since the parent's state has no reason to change in this particular scenario and workers have a limited lifespan, the state of every child will be exactly the same on each run.
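
Conversely, reseeding from the OS entropy pool inside each task gives every worker its own state (a sketch of my own, not part of the original answer):

import os
import random

len(set(
    sc.parallelize(range(4), 4)
      .map(lambda _: random.seed(os.urandom(8)) or random.getstate())
      .collect()
))
# 4 (with overwhelming probability)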

zero323

This seems to be a bug (or feature) of randint. I see the same behavior, but as soon as I change f, the values do indeed change. So I'm not sure of the actual randomness of this method... I can't find any documentation, but it seems to be using some deterministic math algorithm instead of the more variable features of the running machine. Even if I go back and forth, the numbers seem to be the same upon returning to the original value...
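
The determinism itself is easy to see outside Spark; the Mersenne Twister produces the same sequence for the same seed (a quick sketch, not from the original answer):

import random

a = random.Random(42)
b = random.Random(42)
# identical seeds yield identical sequences
assert [a.randint(0, 999) for _ in range(5)] == [b.randint(0, 999) for _ in range(5)]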

Justin Pihony
  • It is a pseudorandom generator implementing the Mersenne Twister, but that shouldn't be a problem. The problem is definitely related to the shared `Random` class (I've edited the question to reflect that), but how it affects `first` output still puzzles me. – zero323 Aug 09 '15 at 15:57

For my use case, most of the solution was buried in an edit at the bottom of the question. However, there is another complication: I wanted to use the same function to generate multiple (different) random columns. It turns out that Spark has an assumption that the output of a UDF is deterministic, which means that it can skip later calls to the same function with the same inputs. For functions that return random output this is obviously not what you want.

To work around this, I generated a separate seed column for every random column that I wanted using the built-in PySpark rand function:

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType
import numpy as np

@F.udf(IntegerType())
def my_rand(seed):
    rs = np.random.RandomState(seed)
    # cast to a plain Python int; NumPy integers may not survive
    # conversion to IntegerType
    return int(rs.randint(1000))

# draw one uniformly distributed 32-bit seed per row
seed_expr = (F.rand() * F.lit(4294967295).astype('double')).astype('bigint')
my_df = (
    my_df
    .withColumn('seed_0', seed_expr)
    .withColumn('seed_1', seed_expr)
    .withColumn('myrand_0', my_rand(F.col('seed_0')))
    .withColumn('myrand_1', my_rand(F.col('seed_1')))
    .drop('seed_0', 'seed_1')
)

I'm using the DataFrame API rather than the RDD of the original problem because that's what I'm more familiar with, but the same concepts presumably apply.

NB: apparently it is possible to disable the assumption of determinism for Scala Spark UDFs since v2.3: https://jira.apache.org/jira/browse/SPARK-20586.
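
PySpark appears to expose the same switch (asNondeterministic) since 2.3 as well; a minimal sketch assuming Spark >= 2.3 (the _unseeded_rand helper is illustrative):

import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

def _unseeded_rand():
    # no seed argument: each call draws from a freshly seeded RandomState
    return int(np.random.RandomState().randint(1000))

# asNondeterministic() tells the optimizer not to collapse repeated
# calls to this UDF with the same inputs
my_rand_nd = F.udf(_unseeded_rand, IntegerType()).asNondeterministic()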

abeboparebop