
How can I select an exact number of random rows from a DataFrame efficiently? The data contains an index column that can be used. If I have to use the maximum size, which is more efficient: count() or max() on the index column?

Boris

2 Answers


A possible approach is to calculate the number of rows using .count(), then use sample() from Python's random library to draw a random selection of the desired length from that range. Lastly, use the resulting list of numbers vals to subset your index column.

import random

def sampler(df, col, records):

  # Calculate the number of rows (assumes 'col' holds consecutive
  # integers 1..N, so the row count equals the maximum index value)
  colmax = df.count()

  # Draw 'records' distinct values from 1..colmax, inclusive of the maximum
  vals = random.sample(range(1, colmax + 1), records)

  # Use 'vals' to filter the DataFrame with 'isin'
  return df.filter(df[col].isin(vals))

Example:

df = sc.parallelize([(1,1),(2,1),
                     (3,1),(4,0),
                     (5,0),(6,1),
                     (7,1),(8,0),
                     (9,0),(10,1)]).toDF(["a","b"])

sampler(df,"a",3).show()
+---+---+
|  a|  b|
+---+---+
|  3|  1|
|  4|  0|
|  6|  1|
+---+---+
mtoto
  • Thank you for your suggestion. This is something I've come to as well. The reason I didn't want to use this solution is the use of the `count()` method, which is very expensive. – Boris Nov 07 '16 at 06:16
  • You can also cache your `df` and then compute `count()` outside the function, or use `agg(max)` (see the sketch after these comments). – mtoto Nov 07 '16 at 06:50
  • Thanks, used your solution in Java. – Boris Nov 08 '16 at 14:46
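
A minimal sketch of the `agg(max)` suggestion, assuming the index column "a" contains consecutive integers starting at 1 so that its maximum equals the row count:

from pyspark.sql import functions as F
import random

# Read the maximum value of the index column instead of calling count()
colmax = df.agg(F.max("a")).collect()[0][0]

# Draw 3 distinct index values from 1..colmax (inclusive) and filter on them
vals = random.sample(range(1, colmax + 1), 3)
df.filter(df["a"].isin(vals)).show()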

Here's an alternative using the pandas DataFrame.sample method. It relies on Spark's applyInPandas method, available from Spark 3.0.0, to distribute the groups, and it allows you to select an exact number of rows per group.

I've added *args and **kwargs to the function so you can access the other arguments of DataFrame.sample.

def sample_n_per_group(n, *args, **kwargs):
    def sample_per_group(pdf):
        return pdf.sample(n, *args, **kwargs)
    return sample_per_group

df = spark.createDataFrame(
    [
        (1, 1.0), 
        (1, 2.0), 
        (2, 3.0), 
        (2, 5.0), 
        (2, 10.0)
    ],
    ("id", "v")
)

(df.groupBy("id")
   .applyInPandas(
        sample_n_per_group(2, random_state=2), 
        schema=df.schema
   )
)
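
For example, a quick way to inspect the result (just a usage sketch; the variable name `sampled` is mine):

sampled = (df.groupBy("id")
             .applyInPandas(
                  sample_n_per_group(2, random_state=2),
                  schema=df.schema
              ))

# Each "id" group contributes exactly two rows
sampled.show()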

Be aware of the limitations for very large groups; from the documentation:

This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.

See also here: How take a random row from a PySpark DataFrame?
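
For reference, a common pattern for taking an exact number of random rows from the whole DataFrame (not per group) is to order by a random value and keep the first n rows. This is only a rough sketch, it may not match the linked answer exactly, and the full sort can be expensive on large data:

from pyspark.sql import functions as F

# Shuffle the whole DataFrame by a random key and keep exactly n rows
n = 2
df.orderBy(F.rand(seed=42)).limit(n).show()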

s_pike