You can create a column with random values and use row_number
to take the same number of random samples (n) from each label:
from pyspark.sql.types import StructType, StructField, IntegerType
from pyspark.sql import functions as F
from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window
n = 333333 # number of samples per label
df = df.withColumn("rand_col", F.rand())
sample_df1 = (df.withColumn("row_num", row_number().over(Window.partitionBy("label")
                                                               .orderBy("rand_col")))
                .filter(col("row_num") <= n)
                .drop("rand_col", "row_num"))
sample_df1.groupBy("label").count().show()
This will always give you exactly n samples for each label.
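If some label has fewer than n rows, the filter simply keeps all of that label's rows. If you want to enforce a strictly equal count across labels, one option (just a sketch, not part of the original approach) is to cap n at the size of the smallest label group first:
# Cap n at the smallest per-label count so every label can supply exactly n rows
min_count = df.groupBy("label").count().agg(F.min("count")).collect()[0][0]
n = min(n, min_count)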
Another way of doing this is stratified sampling using Spark's stat.sampleBy:
n = 333333
seed = 12345
# Create a dictionary of sampling fractions (n / count) for each label
fractions = (df.groupBy("label").count()
               .withColumn("fraction", n / col("count"))
               .drop("count")
               .rdd.collectAsMap())
sample_df2 = df.stat.sampleBy("label", fractions, seed)
sample_df2.groupBy("label").count().show()
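fractions is just a plain Python dict mapping each label value to its sampling fraction (n divided by that label's row count), so you can inspect it before passing it to sampleBy:
print(fractions)
# e.g. {1: 0.33, 2: 0.33, 3: 0.34} -- illustrative only; the exact values depend on your label counts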
sampleBy, however, gives an approximate result that can vary from run to run and does not guarantee an exact number of records per label.
Example dataframe:
schema = StructType([StructField("id", IntegerType()), StructField("label", IntegerType())])
data = [[1, 2], [1, 2], [1, 3], [2, 3], [1, 2], [1, 1], [1, 2], [1, 3], [2, 2], [1, 1], [1, 2], [1, 2], [1, 3], [2, 3], [1, 1]]
df = spark.createDataFrame(data, schema=schema)
df.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 7|
| 3| 5|
+-----+-----+
Method 1:
# Sampling 3 records from each label
n = 3
# Assign a column with random values
df = df.withColumn("rand_col", F.rand())
sample_df1 = (df.withColumn("row_num", row_number().over(Window.partitionBy("label")
                                                               .orderBy("rand_col")))
                .filter(col("row_num") <= n)
                .drop("rand_col", "row_num"))
sample_df1.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 3|
| 3| 3|
+-----+-----+
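As an aside (not in the original code), F.rand() also accepts a seed, so Method 1 can be made reproducible:
df = df.withColumn("rand_col", F.rand(seed=12))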
Method 2:
# Sampling 3 records from each label
n = 3
seed = 12
fractions = (df.groupBy("label").count()
               .withColumn("fraction", n / col("count"))
               .drop("count")
               .rdd.collectAsMap())
sample_df2 = df.stat.sampleBy("label", fractions, seed)
sample_df2.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 3|
| 3| 4|
+-----+-----+
As you can see, sampleBy
gives you an approximately equal distribution, but not an exact one. I'd prefer Method 1 for your problem.
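If you need to do this repeatedly, Method 1 wraps naturally into a small helper (a sketch under the same assumptions as above; the function name is mine, not from any library):
def sample_n_per_label(df, label_col, n, seed=None):
    # Rank rows randomly within each label and keep the first n per label
    w = Window.partitionBy(label_col).orderBy("rand_col")
    return (df.withColumn("rand_col", F.rand(seed))
              .withColumn("row_num", row_number().over(w))
              .filter(col("row_num") <= n)
              .drop("rand_col", "row_num"))

sample_df = sample_n_per_label(df, "label", 3, seed=12)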