How can I efficiently select an exact number of random rows from a PySpark DataFrame? The data contains an index column that can be used. If I have to determine the maximum size, which is more efficient: count() or max() on the index column?
- can't you just use `df.sample()`? – mtoto Nov 06 '16 at 21:05
- @mtoto sample() returns an approximate number of rows, but the algorithm requires an exact number in certain scenarios (see the sketch below). – Boris Nov 06 '16 at 21:10
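For context, a minimal sketch of that point (the `spark` session and the `big_df` name are illustrative, not from the thread): Spark's `sample()` takes a fraction, so the number of rows it returns only approximates fraction × total and is not exact.

    # df.sample() is fraction-based: the returned row count is only approximately
    # fraction * total and can vary between runs, so it cannot guarantee an exact
    # number of sampled rows.
    big_df = spark.range(1000)
    print(big_df.sample(False, 0.1, seed=42).count())  # roughly 100, not exactly 100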
2 Answers
A possible approach is to calculate the number of rows using `.count()`, then use `sample()` from Python's `random` library to generate a random sequence of arbitrary length from this range. Lastly, use the resulting list of numbers `vals` to subset your index column.
import random

def sampler(df, col, records):
    # Calculate number of rows
    colmax = df.count()
    # Create random sample from the range; use colmax + 1 so the last index
    # can be selected (the index column runs from 1 to colmax)
    vals = random.sample(range(1, colmax + 1), records)
    # Use 'vals' to filter the DataFrame using 'isin'
    return df.filter(df[col].isin(vals))
Example:
df = sc.parallelize([(1, 1), (2, 1),
                     (3, 1), (4, 0),
                     (5, 0), (6, 1),
                     (7, 1), (8, 0),
                     (9, 0), (10, 1)]).toDF(["a", "b"])

sampler(df, "a", 3).show()
+---+---+
| a| b|
+---+---+
| 3| 1|
| 4| 0|
| 6| 1|
+---+---+

- Thank you for your suggestion; this is something I've come to as well. The reason I didn't want to use this solution is the use of the **count()** method, which is very expensive. – Boris Nov 07 '16 at 06:16
- You can also cache your `df` and then compute `count()` outside the function, or use `agg(max)`; see the sketch below. – mtoto Nov 07 '16 at 06:50
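A minimal sketch of that suggestion (assuming the index column is named `a` and runs densely from 1 to N, as in the example above; the exact calls are my reading of the comment):

    import random
    from pyspark.sql import functions as F

    df.cache()  # keep the data in memory so it is only scanned once

    # Option 1: upper bound from a full row count
    colmax = df.count()

    # Option 2: upper bound from the maximum of the index column
    colmax = df.agg(F.max("a")).collect()[0][0]

    # Reuse the precomputed bound instead of calling count() inside sampler()
    vals = random.sample(range(1, colmax + 1), 3)
    df.filter(df["a"].isin(vals)).show()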
Here's an alternative using the pandas `DataFrame.sample` method. It uses Spark's `applyInPandas` method, available from Spark 3.0.0, to distribute the groups. This allows you to select an exact number of rows per group. I've added `args` and `kwargs` to the function so you can access the other arguments of `DataFrame.sample`.
def sample_n_per_group(n, *args, **kwargs):
    def sample_per_group(pdf):
        return pdf.sample(n, *args, **kwargs)
    return sample_per_group
df = spark.createDataFrame(
    [
        (1, 1.0),
        (1, 2.0),
        (2, 3.0),
        (2, 5.0),
        (2, 10.0)
    ],
    ("id", "v")
)
(df.groupBy("id")
   .applyInPandas(
       sample_n_per_group(2, random_state=2),
       schema=df.schema
   )
)
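A usage sketch (my addition, assuming the definitions above): assign the result, show it, and confirm that each group contributes exactly n rows.

    sampled = (df.groupBy("id")
                 .applyInPandas(sample_n_per_group(2, random_state=2),
                                schema=df.schema))

    sampled.show()
    # Every id should appear exactly twice; id 1 has only two rows, so both are kept.
    sampled.groupBy("id").count().show()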
Be aware of the limitations for very large groups; from the documentation:
This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.
See also here: How take a random row from a PySpark DataFrame?
