I want to carry out stratified sampling from a data frame in PySpark. There is a `sampleBy(col, fractions, seed=None)`
function, but it seems to support only a single column as the stratum. Is there any way to use multiple columns as strata?
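For reference, single-column stratified sampling would look something like this (a minimal sketch; the fractions and values are illustrative, and `sc` is an existing SparkContext):

# keep ~30% of rows where ID == 1 and ~50% of rows where ID == 2; unlisted keys default to 0
df = sc.parallelize([(1, 10), (1, 11), (2, 20), (3, 30)]).toDF(["ID", "X"])
sample = df.sampleBy("ID", fractions={1: 0.3, 2: 0.5}, seed=42)
sample.show()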

2 Answers
Based on the answer here, after converting it to Python, I think an answer might look like:
# create a dataframe to use
df = sc.parallelize([(1, 1234, 282), (1, 1396, 179), (2, 8620, 178), (3, 1620, 191), (3, 8820, 828)]).toDF(["ID", "X", "Y"])

# we are going to use the first two columns as our key (strata)
# assign sampling percentages to each key (you could do something cooler here)
fractions = df.rdd.map(lambda x: (x[0], x[1])).distinct().map(lambda x: (x, 0.3)).collectAsMap()

# set up how we want to key the dataframe
kb = df.rdd.keyBy(lambda x: (x[0], x[1]))

# create a dataframe after sampling from our newly keyed rdd
# note: if the sample did not return any values, you'll get a `ValueError: RDD is empty` error
sampleddf = kb.sampleByKey(False, fractions).map(lambda x: x[1]).toDF(df.columns)
sampleddf.show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 1|1234|282|
| 1|1396|179|
| 3|1620|191|
+---+----+---+
# other examples (each call draws a new random sample since no seed is given)
kb.sampleByKey(False, fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 2|8620|178|
+---+----+---+
kb.sampleByKey(False, fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 1|1234|282|
| 1|1396|179|
+---+----+---+
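If you want different sampling rates per stratum (the "something cooler" hinted at in the comment above), you can compute a different fraction for each key when building the map. A minimal sketch, with made-up rates:

# use a higher rate for keys whose ID is 1 and a lower rate for everything else (illustrative values)
fractions = df.rdd.map(lambda x: (x[0], x[1])).distinct().map(lambda k: (k, 0.5 if k[0] == 1 else 0.2)).collectAsMap()
sampleddf = kb.sampleByKey(False, fractions).map(lambda x: x[1]).toDF(df.columns)
sampleddf.show()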
Is this the kind of thing you were looking for?

- Yes! This is what I was looking for. I wanted to change the sample fractions for different keys, but it will be trivial after looking at this. – ysakamoto May 09 '17 at 21:00
- Also there's a typo on your 2nd line `sc.parallelize([] ...` – ysakamoto May 09 '17 at 21:00
James Tobin's solution above works fine with the example presented, but I had difficulties replicating the approach on my dataset (nearly 2 million records). A strange Java-related runtime error occurred and I was not able to pinpoint the issue (I was running PySpark in local mode).
An alternative is to adapt the single-column stratified-sampling approach: create a new (temporary) column that merges the values of the multiple columns we originally wanted to stratify on, perform the splits using that single column, and then drop the merged column from the resulting splits.
from pyspark.sql import functions as F

def get_stratified_split_multiple_columns(input_df, col_name1, col_name2, seed_value=1234, train_frac=0.6):
    """
    Following the approach of stratified sampling based on a single column as
    presented at https://stackoverflow.com/a/47672336/530399 .
    However, this time our single column is going to be a merger
    of the values present in multiple columns (`col_name1` and `col_name2`).

    Note that pyspark's split is not exact. Therefore, if there are too few
    examples per category, it can be that none of the examples go to the
    validation/test split, which results in an error.
    """
    merged_col_name = "both_labels"
    # The "_#_@_#_" acts as a separator between the values of the two columns.
    input_df = input_df.withColumn(merged_col_name,
                                   F.concat(F.col(col_name1), F.lit('_#_@_#_'), F.col(col_name2)))

    # One sampling fraction per distinct merged value.
    fractions1 = input_df.select(merged_col_name).distinct() \
        .withColumn("fraction", F.lit(train_frac)).rdd.collectAsMap()
    train_df = input_df.stat.sampleBy(merged_col_name, fractions1, seed_value)

    valid_and_test_df = input_df.exceptAll(train_df)
    fractions2 = {key: 0.5 for key in fractions1}  # 0.5 for an equal split of valid and test sets
    valid_df = valid_and_test_df.stat.sampleBy(merged_col_name, fractions2, seed_value)
    test_df = valid_and_test_df.exceptAll(valid_df)

    # Delete the merged_col_name column from all splits.
    train_df = train_df.drop(merged_col_name)
    valid_df = valid_df.drop(merged_col_name)
    test_df = test_df.drop(merged_col_name)

    return train_df, valid_df, test_df
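A minimal usage sketch (assuming an active SparkSession named `spark`; the toy DataFrame and the column names `label_a`/`label_b` are made up for illustration):

# build a toy dataframe with enough rows per category, then split it 60/20/20 stratified on (label_a, label_b)
df = spark.createDataFrame(
    [("a", "x", 1), ("a", "y", 2), ("b", "x", 3), ("b", "y", 4)] * 25,
    ["label_a", "label_b", "value"])
train_df, valid_df, test_df = get_stratified_split_multiple_columns(df, "label_a", "label_b",
                                                                    seed_value=42, train_frac=0.6)
print(train_df.count(), valid_df.count(), test_df.count())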
