
I have a Spark DataFrame with one column that contains lots of zeros and very few ones (only 0.01% are ones).

I'd like to take a random subsample, but a stratified one, so that it keeps the ratio of 1s to 0s in that column.

Is it possible to do this in PySpark?

I am looking for a non-Scala solution that is based on DataFrames rather than RDDs.
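To make the goal concrete, here is a minimal plain-Python sketch (no Spark involved) of an exact stratified subsample: rows are grouped by label, and the same fraction, rounded to a whole number of rows, is drawn from each group, so the ratio of 1s to 0s is preserved up to rounding. The helper `stratified_subsample` and the toy data are hypothetical, and this fixed-quota approach is only an illustration, not how Spark's `sampleBy` works internally.

```python
import random
from collections import defaultdict


def stratified_subsample(rows, label_of, fraction, seed=42):
    """Draw round(fraction * n) rows from each label group of size n."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for row in rows:
        groups[label_of(row)].append(row)

    sample = []
    for label in sorted(groups):  # sorted so the result is reproducible
        group = groups[label]
        k = round(fraction * len(group))
        sample.extend(rng.sample(group, k))
    return sample


# Toy data: 10,000 (row_id, label) rows, of which 100 (1%) have label 1.
rows = [(i, 1 if i < 100 else 0) for i in range(10_000)]
sample = stratified_subsample(rows, label_of=lambda r: r[1], fraction=0.1)
ones = sum(1 for _, y in sample if y == 1)
print(len(sample), ones)  # 1000 10
```

With very rare labels, exact quotas like this avoid the extra variance of per-row random sampling, at the cost of having to know each group's size.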

eliasah
user3245256

7 Answers


The solution I suggested in "Stratified sampling in Spark" is pretty straightforward to convert from Scala to Python (or even to Java; see "What's the easiest way to stratify a Spark Dataset?").

Nevertheless, I'll rewrite it in Python. Let's start by creating a toy DataFrame:

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = SparkSession.builder.getOrCreate()
# (x1, x2, x3) tuples; x1 is the column we will stratify on
data = [(2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),(2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),(2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)]
df = spark.createDataFrame(data, ["x1", "x2", "x3"])
df.show()
# +----------+----------+---+
# |        x1|        x2| x3|
# +----------+----------+---+
# |2147481832|  23355149|  1|
# |2147481832| 973010692|  1|
# |2147481832|2134870842|  1|
# |2147481832| 541023347|  1|
# |2147481832|1682206630|  1|
# |2147481832|1138211459|  1|
# |2147481832| 852202566|  1|
# |2147481832| 201375938|  1|
# |2147481832| 486538879|  1|
# |2147481832| 919187908|  1|
# | 214748183| 919187908|  1|
# | 214748183|  91187908|  1|
# +----------+----------+---+

As you can see, this DataFrame has 12 rows:

df.count()
# 12

They are distributed as follows:

df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|   10|
# | 214748183|    2|
# +----------+-----+

Now let's sample.

First, we'll set the seed:

seed = 12

Then we find the distinct keys to stratify on, assign each one a sampling fraction, and sample:

fractions = df.select("x1").distinct().withColumn("fraction", lit(0.8)).rdd.collectAsMap()
print(fractions)                                                            
# {2147481832: 0.8, 214748183: 0.8}
sampled_df = df.stat.sampleBy("x1", fractions, seed)
sampled_df.show()
# +----------+---------+---+
# |        x1|       x2| x3|
# +----------+---------+---+
# |2147481832| 23355149|  1|
# |2147481832|973010692|  1|
# |2147481832|541023347|  1|
# |2147481832|852202566|  1|
# |2147481832|201375938|  1|
# |2147481832|486538879|  1|
# |2147481832|919187908|  1|
# | 214748183|919187908|  1|
# | 214748183| 91187908|  1|
# +----------+---------+---+

We can now check the content of our sample:

sampled_df.count()
# 9

sampled_df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|    7|
# | 214748183|    2|
# +----------+-----+
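Why 7 rows for the first key instead of 8? As far as I can tell, `sampleBy` does not draw an exact quota per stratum: it compares one uniform random number per row (a filter on `rand(seed)` in Spark's implementation) with the fraction for that row's key, so each stratum's sample size is only 0.8 * n on average. The plain-Python sketch below imitates that behaviour; `bernoulli_sample_by` is a hypothetical name, not a Spark API.

```python
import random


def bernoulli_sample_by(rows, key_of, fractions, seed):
    """Keep each row independently when a uniform draw is below its key's fraction."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions[key_of(row)]]


# Same shape as the toy DataFrame: 10 rows for key "a", 2 rows for key "b".
rows = [("a", i) for i in range(10)] + [("b", i) for i in range(2)]
fractions = {"a": 0.8, "b": 0.8}

# Re-run with 2,000 different seeds and count how many "a" rows survive each time.
counts_a = []
for s in range(2000):
    sampled = bernoulli_sample_by(rows, lambda r: r[0], fractions, seed=s)
    counts_a.append(sum(1 for key, _ in sampled if key == "a"))

mean_a = sum(counts_a) / len(counts_a)
print(min(counts_a), max(counts_a), round(mean_a, 2))
```

The count for "a" moves around from seed to seed but averages close to 8, which is why a single run can return 7.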
eliasah
    @eliasah is there any way to add 0.8 and 0.2 fractions? I want to use 0.8 as training set and the other 0.2 as test set. I tried to get the 0.8 using this approach but have difficulties getting the other 0.2 in spark 1.6 where there is no sub query support – Rio Aug 01 '18 at 19:13
  • You can always use `except` on the main DF and the sampled DF @EmmaNej – eliasah Aug 02 '18 at 07:47
  • @eliasah Yeah but that would take a long time considering that I have 20 Million records and no unique key in the dataset. – Rio Aug 02 '18 at 14:10
    @EmmaNej Then https://stackoverflow.com/questions/39887526/filter-spark-dataframe-based-on-another-dataframe-that-specifies-blacklist-crite/39889263#39889263 – eliasah Aug 02 '18 at 14:18
  • @eliasah unfortunately Spark 1.6 does not support left_anti join. – Rio Aug 02 '18 at 14:32
  • @eliasah I get that 0.8 here is the variable sampling rates for different keys. the question is how can I choose this value for 12 different keys or classes – Emna Jaoua Jul 11 '19 at 10:51
  • It already take into account different classes/keys. We need a stratified sample according to the column `x1`, so we create a map of keys and what is the fraction of the key we need and we ask spark to give fractionned data by key. – eliasah Jul 11 '19 at 12:33
  • so the choice of the 0.8 value is determined by maximum_value(df.groupBy("x1").count().show())/df.count(). is that correct ? – Emna Jaoua Jul 11 '19 at 12:52
  • No. I've decided that I need 80% of my data stratified. Thus the 0.8 i.e `lit(0.8)`. If I needed 70%, I'd have used 0.7... I advice you to read my answer for the scala API (https://stackoverflow.com/questions/32238727/stratified-sampling-in-spark/32241887#32241887) The explanation is more complete there. – eliasah Jul 11 '19 at 13:02
  • ah okey, I see. Thank you for the clarifícation :) – Emna Jaoua Jul 11 '19 at 15:42
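For the 80/20 train/test question in the comments, one way to avoid `except` or anti joins entirely is to draw a single random number per row and route the row by it: below 0.8 goes to train, everything else to test. The two sets are then disjoint by construction and together contain every row, and within each label about 80% of rows land in train. In PySpark this would presumably mean adding a column with `F.rand(seed)` and filtering on it twice, ideally after caching or checkpointing so the column is not recomputed differently; that port is my assumption, not something from the answer. The sketch below shows the idea in plain Python, with the hypothetical helper `split_by_draw`.

```python
import random


def split_by_draw(rows, train_fraction, seed=12):
    """Send each row to train if its single uniform draw is below train_fraction."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test


# Toy data: (row_id, label) with 1,000 zeros and 100 ones.
rows = [(i, 0) for i in range(1000)] + [(1000 + i, 1) for i in range(100)]
train, test = split_by_draw(rows, train_fraction=0.8)

ones_train = sum(1 for _, y in train if y == 1)
ones_test = sum(1 for _, y in test if y == 1)
print(len(train), len(test), ones_train, ones_test)
```

With the same fraction for every label, this is statistically the same per-row sampling as `sampleBy`, plus its exact complement.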

Assume you have the Titanic dataset in a DataFrame called 'data', which you want to split into a train set and a test set using stratified sampling on the 'Survived' target variable.

# Check the initial distribution of 0s and 1s
data.groupBy("Survived").count().show()
# +--------+-----+
# |Survived|count|
# +--------+-----+
# |       1|  342|
# |       0|  549|
# +--------+-----+

# Take 70% of both the 0s and the 1s for the training set
train = data.sampleBy("Survived", fractions={0: 0.7, 1: 0.7}, seed=10)

# Subtract 'train' from the original 'data' to get the test set
test = data.subtract(train)

# Check the distributions of 0s and 1s in the train and test sets after sampling
train.groupBy("Survived").count().show()
# +--------+-----+
# |Survived|count|
# +--------+-----+
# |       1|  239|
# |       0|  399|
# +--------+-----+
test.groupBy("Survived").count().show()
# +--------+-----+
# |Survived|count|
# +--------+-----+
# |       1|  103|
# |       0|  150|
# +--------+-----+
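One caveat with `data.subtract(train)`: the PySpark documentation describes `subtract` as equivalent to `EXCEPT DISTINCT` in SQL, so it compares whole rows and also removes duplicates. If `data` contains identical rows, an unsampled copy disappears from `test` whenever an identical row went to `train`, and duplicates within the test set collapse to one. On Spark 2.4 and later, `exceptAll` performs a multiset difference that keeps duplicates. The plain-Python sketch below mimics both semantics on toy rows; the function names are hypothetical.

```python
from collections import Counter


def except_distinct(left, right):
    """Like DataFrame.subtract / EXCEPT DISTINCT: set difference, duplicates removed."""
    right_set = set(right)
    seen, out = set(), []
    for row in left:
        if row not in right_set and row not in seen:
            seen.add(row)
            out.append(row)
    return out


def except_all(left, right):
    """Like DataFrame.exceptAll / EXCEPT ALL: multiset difference."""
    remaining = Counter(right)
    out = []
    for row in left:
        if remaining[row] > 0:
            remaining[row] -= 1  # cancel one matching copy from the right side
        else:
            out.append(row)
    return out


data = [("a", 1), ("a", 1), ("b", 0), ("c", 0), ("c", 0)]
train = [("a", 1), ("c", 0)]  # one copy of each duplicated row was sampled

print(except_distinct(data, train))  # [('b', 0)]
print(except_all(data, train))       # [('a', 1), ('b', 0), ('c', 0)]
```

Only the multiset version gives a test set that, together with `train`, accounts for every original row.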
Ankit Sharma
  • Code may look cleaner and smaller but `subtract` function takes very long to run. Answer shared below is faster : https://stackoverflow.com/a/61016937/8357062 – MSS Dec 12 '22 at 07:06

This can be accomplished pretty easily with 'randomSplit' and 'union' in PySpark.

# read in the data; `file` is the path to your CSV, and inferSchema=True makes "Target" numeric
df = spark.read.csv(file, header=True, inferSchema=True)
# split the DataFrame into the 0s and the 1s
zeros = df.filter(df["Target"] == 0)
ones = df.filter(df["Target"] == 1)
# split each subset into training and testing sets with weights 0.8/0.2
train0, test0 = zeros.randomSplit([0.8, 0.2], seed=1234)
train1, test1 = ones.randomSplit([0.8, 0.2], seed=1234)
# stack the datasets back together
train = train0.union(train1)
test = test0.union(test1)
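The same pattern extends to any number of label values: split each label's subset with the same weights and union the pieces position by position (in PySpark, presumably by looping over the distinct labels and combining with `functools.reduce(DataFrame.union, ...)`). `randomSplit` also normalizes its weights, so `[0.8, 0.2]` and `[4, 1]` describe the same split. The plain-Python sketch below, with the hypothetical helper `split_per_label`, routes each row by one uniform draw against the cumulative normalized weights; that is a simplification for illustration, not Spark's internal implementation.

```python
import random
from bisect import bisect_right
from collections import defaultdict
from itertools import accumulate


def split_per_label(rows, label_of, weights, seed=1234):
    """Split rows into len(weights) parts, splitting each label group separately."""
    total = sum(weights)
    # Normalized cumulative boundaries, e.g. [0.8, 0.2] or [4, 1] -> [0.8, 1.0].
    bounds = [c / total for c in accumulate(weights)]

    by_label = defaultdict(list)
    for row in rows:
        by_label[label_of(row)].append(row)

    parts = [[] for _ in weights]
    for label in sorted(by_label):
        rng = random.Random(seed)  # same seed for every label, as in the answer
        for row in by_label[label]:
            idx = bisect_right(bounds, rng.random())
            parts[min(idx, len(parts) - 1)].append(row)  # guard against round-off
    return parts


# Toy data: (row_id, label) with 900 zeros and 100 ones.
rows = [(i, 0) for i in range(900)] + [(900 + i, 1) for i in range(100)]
train, test = split_per_label(rows, label_of=lambda r: r[1], weights=[0.8, 0.2])
print(len(train), len(test), sum(1 for _, y in train if y == 1))
```

Because every label is split independently, each output part contains roughly the same share of every label, whatever the number of classes.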
yeamusic21

This is based on the accepted answer by @eliasah and this SO thread.

If you want to get back a train set and a test set, you can use the following function:

from pyspark.sql import functions as F

def stratified_split_train_test(df, frac, label, join_on, seed=42):
    """Stratified split of a DataFrame into a train and a test set.

    `join_on` must be a column (or columns) that uniquely identifies each row.
    Inspiration from:
    https://stackoverflow.com/a/47672336/1771155
    https://stackoverflow.com/a/39889263/1771155"""
    # the same sampling fraction for every distinct label value
    fractions = df.select(label).distinct().withColumn("fraction", F.lit(frac)).rdd.collectAsMap()
    df_frac = df.stat.sampleBy(label, fractions, seed)
    # every row that was not sampled goes to the test set
    df_remaining = df.join(df_frac, on=join_on, how="left_anti")
    return df_frac, df_remaining

To create a stratified train and test set where 80% of the data goes to the training set:

df_train, df_test = stratified_split_train_test(df=df, frac=0.8, label="y", join_on="unique_id")
Vincent Claes
  • I do not have a 'unique_id' columns in my dataset. Is there a way this function can be re-written? – Behzad Rowshanravan Aug 31 '20 at 20:12
  • @BehzadRowshanravan https://stackoverflow.com/a/72379515/8231475 – Shahidur May 25 '22 at 14:29
  • @BehzadRowshanravan I'm passing my `df` to this function and expect to return `df_train, df_val, df_test` but surprisingly I can't `display(df_train)`. due to this error: *NameError: name 'df_train' is not defined*. Why it is the case in databricks? – Mario Aug 18 '22 at 18:27
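The `left_anti` join only yields the exact complement of `df_frac` when `join_on` identifies each row uniquely. If several rows share a key value and any one of them is sampled, the anti join removes all of them from the remaining set, so they vanish from both splits. That is why a unique id is needed; the function in the next answer adds one with `f.monotonically_increasing_id()`. A plain-Python illustration of the anti-join semantics, using a hypothetical `left_anti` helper:

```python
def left_anti(left, right, key):
    """Rows of `left` whose key does not appear in `right` (like how="left_anti")."""
    right_keys = {key(row) for row in right}
    return [row for row in left if key(row) not in right_keys]


# Rows are (id, label); id 2 is deliberately duplicated.
rows = [(1, 0), (2, 0), (2, 1), (3, 1)]
sampled = [(2, 0)]  # suppose only this row was sampled

remaining = left_anti(rows, sampled, key=lambda r: r[0])
print(remaining)  # [(1, 0), (3, 1)]; row (2, 1) is now in neither split

# With a unique per-row id, the anti join is an exact complement.
with_ids = list(enumerate(rows))  # (row_id, original_row)
sampled_ids = [with_ids[1]]       # the same sampled row, now tagged with id 1
remaining_ids = left_anti(with_ids, sampled_ids, key=lambda r: r[0])
print(len(remaining_ids) + len(sampled_ids) == len(rows))  # True
```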

You can use the function below, which combines the other answers.

from typing import Optional

import pyspark.sql.functions as f
from pyspark.sql import DataFrame as SparkDataFrame


def train_test_split_pyspark(
    df: SparkDataFrame,
    stratify_column: str,
    unique_col: Optional[str] = None,
    train_fraction: float = 0.05,
    validation_fraction: float = 0.005,
    test_fraction: float = 0.005,
    seed: int = 1234,
    to_pandas: bool = True,
):
    # Note: validation_fraction and test_fraction are applied to the rows left over
    # after the previous split, not to the full DataFrame.
    if not unique_col:
        # temporary id for the anti joins below (unique, but not consecutive)
        unique_col = "any_unique_name_here"
        df = df.withColumn(unique_col, f.monotonically_increasing_id())

    # Train data: the same fraction for every distinct value of the stratify column
    train_fraction_dict = (
        df.select(stratify_column)
        .distinct()
        .withColumn("fraction", f.lit(train_fraction))
        .rdd.collectAsMap()
    )
    df_train = df.stat.sampleBy(stratify_column, train_fraction_dict, seed)
    df_remaining = df.join(df_train, on=unique_col, how="left_anti")

    # Validation data, sampled from the rows not used for training
    validation_fraction_dict = dict.fromkeys(train_fraction_dict, validation_fraction)
    df_val = df_remaining.stat.sampleBy(stratify_column, validation_fraction_dict, seed)
    df_remaining = df_remaining.join(df_val, on=unique_col, how="left_anti")

    # Test data, sampled from the rows used for neither training nor validation
    test_fraction_dict = dict.fromkeys(train_fraction_dict, test_fraction)
    df_test = df_remaining.stat.sampleBy(stratify_column, test_fraction_dict, seed)

    # drop the temporary id column if we added it
    if unique_col == "any_unique_name_here":
        df_train = df_train.drop(unique_col)
        df_val = df_val.drop(unique_col)
        df_test = df_test.drop(unique_col)

    if to_pandas:
        return (df_train.toPandas(), df_val.toPandas(), df_test.toPandas())
    return df_train, df_val, df_test
Shahidur
  • I'm passing my `df` to this function and expect to return `df_train, df_val, df_test` but surprisingly I can't `display(df_train)`. due to this error: *NameError: name 'df_train' is not defined*. Why it is the case in databricks? – Mario Aug 18 '22 at 18:26
  • Hi @Mario I assume you deactivated the `to_pandas` variable. Do you get the return as a tuple of three items? – Shahidur Aug 25 '22 at 18:30
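In the function above, each later `sampleBy` runs on whatever is left, so `validation_fraction` and `test_fraction` are fractions of the remaining rows rather than of the whole DataFrame. With the defaults, the validation set is about 0.005 * (1 - 0.05) = 0.00475, i.e. roughly 0.475% of the data. If you think in overall shares instead, you can convert them to conditional fractions first. A small sketch of that arithmetic; the helper name `conditional_fractions` is hypothetical.

```python
def conditional_fractions(shares):
    """Turn overall shares, e.g. [0.7, 0.15, 0.15], into per-step fractions
    for sampling sequentially from whatever is left after earlier steps."""
    if any(s < 0 for s in shares) or sum(shares) > 1 + 1e-9:
        raise ValueError("shares must be non-negative and sum to at most 1")
    result, left = [], 1.0
    for share in shares:
        result.append(share / left if left > 0 else 0.0)
        left -= share
    return result


# 70/15/15 overall -> about 0.7 of everything, then 0.5 of the rest, then all of what remains.
print(conditional_fractions([0.7, 0.15, 0.15]))
print(conditional_fractions([0.05, 0.005, 0.005]))
```

When the last conditional fraction comes out as (nearly) 1.0, it is simpler to take the remainder directly instead of sampling it, since floating-point round-off can leave it just below 1.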

To avoid rows ending up in both the train and test splits, or in neither, I would add one step to Vincent Claes's solution: materialize the sample with `localCheckpoint()` before the anti join.

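from pyspark.sql import DataFrame
from pyspark.sql import functions as f


def stratifiedSampler(sparkDf: DataFrame, ratio: float,
                      label: str, joinOn: str, seed=42):
    # the same sampling ratio for every distinct label value
    fractions = (sparkDf.select(label).distinct()
                 .withColumn("fraction", f.lit(ratio))
                 .rdd.collectAsMap())

    fracDf = sparkDf.stat.sampleBy(label, fractions, seed)
    # materialize the sample once so the anti join and the returned split see the same rows
    # (local checkpoints live in executor storage, so they are not fault tolerant)
    fracDf = fracDf.localCheckpoint()

    remainingDf = sparkDf.join(fracDf, on=joinOn, how="left_anti")
    return (fracDf, remainingDf)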
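Why materialize? `fracDf` is used twice, once in the anti join and once as the returned split, and Spark may evaluate its lineage separately for each use. If anything upstream is non-deterministic (for example ids from `monotonically_increasing_id()`, or input whose partitioning can change between evaluations), the two evaluations need not agree, which is how rows can land in both splits or in neither. `localCheckpoint()` computes the sample once and reuses it. The plain-Python toy below imitates this, assuming a hypothetical `evaluate_sample` that draws fresh random numbers on every evaluation.

```python
import random


def evaluate_sample(rows, fraction, rng):
    """Stand-in for evaluating an unmaterialized sample: draws fresh numbers each call."""
    return {r for r in rows if rng.random() < fraction}


rows = set(range(1000))
# Hypothetical upstream source whose state differs between evaluations.
rng = random.Random(0)

# Not materialized: the sample is evaluated once for the returned split
# and once more inside the anti join, and the two evaluations differ.
train_lazy = evaluate_sample(rows, 0.8, rng)
test_lazy = rows - evaluate_sample(rows, 0.8, rng)
in_both = len(train_lazy & test_lazy)
in_neither = len(rows - train_lazy - test_lazy)

# Materialized once (the role localCheckpoint plays): both uses share one result.
train_mat = evaluate_sample(rows, 0.8, rng)
test_mat = rows - train_mat
print(in_both, in_neither, len(train_mat & test_mat))
```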
Kay
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [(2147481832,23355149,'v0'),(2147481832,973010692,'v3'),
        (2147481832,2134870842,'v1'),(2147481832,541023347,'v3'),
        (2147481832,1682206630,'v2'),(2147481832,1138211459,'v4'),
        (2147481832,852202566,'v2'),(2147481832,201375938,'v5'),
        (2147481832,486538879,'v3'),(2147481832,919187908,'v4'),
        (214748183,919187908,'v3'),(214748183,91187908,'v4')]

df = spark.createDataFrame(data, ["x1", "x2", "x3"])

# stratify on x3, keeping about 20% of the rows for each listed value
df = df.sampleBy("x3", fractions={'v1': 0.2, 'v2': 0.2, 'v3': 0.2, 'v4': 0.2, 'v5': 0.2}, seed=0)
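Here the strata are the string values of `x3`, and `'v0'` is not in `fractions`. According to the `sampleBy` documentation, a stratum that is not specified is treated as having fraction zero, so `'v0'` rows never appear in the sample. A plain-Python sketch of that rule; `sample_by` is a hypothetical helper, not the Spark method.

```python
import random


def sample_by(rows, key_of, fractions, seed=0):
    """Keep each row with probability fractions.get(key, 0.0); unlisted keys are dropped."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions.get(key_of(row), 0.0)]


rows = [("v0", 1), ("v0", 2), ("v1", 3), ("v2", 4), ("v3", 5), ("v3", 6)]
fractions = {"v1": 1.0, "v2": 1.0, "v3": 1.0}  # 'v0' deliberately left out
print(sample_by(rows, lambda r: r[0], fractions))
# [('v1', 3), ('v2', 4), ('v3', 5), ('v3', 6)]
```

If you want a stratum kept, list it explicitly; building the dict from the distinct values of the column, as in the accepted answer, avoids silently dropping one.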
KYS