
I have a DataFrame in Spark 2, shown below, where users have anywhere from 50 to thousands of posts each. I would like to create a new DataFrame that keeps all the users from the original DataFrame but only 5 randomly sampled posts for each user.

+--------+--------------+--------------------+
| user_id|       post_id|                text|
+--------+--------------+--------------------+
|67778705|44783131591473|some text...........|
|67778705|44783134580755|some text...........|
|67778705|44783136367108|some text...........|
|67778705|44783136970669|some text...........|
|67778705|44783138143396|some text...........|
|67778705|44783155162624|some text...........|
|67778705|44783688650554|some text...........|
|68950272|88655645825660|some text...........|
|68950272|88651393135293|some text...........|
|68950272|88652615409812|some text...........|
|68950272|88655744880460|some text...........|
|68950272|88658059871568|some text...........|
|68950272|88656994832475|some text...........|
+--------+--------------+--------------------+

Something like `posts.groupby('user_id').agg(sample('post_id'))` would do it, but there is no such function in PySpark.

Any advice?

Update:

This question is different from another closely related question, stratified-sampling-in-spark, in two ways:

  1. It asks about disproportionate stratified sampling rather than the common proportionate method covered in that other question.
  2. It asks about doing this with Spark's DataFrame API rather than the RDD API.

I have also updated the question's title to clarify this.

Majid Alfifi

2 Answers


You can use the `.sampleBy(...)` method for DataFrames: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy

Here's a working example:

import numpy as np
import string
import random

# only needed for the commented-out balanced-sample variant below
from pyspark.sql import functions as func

# generate some fake data
p = [(
    str(int(e)), 
    ''.join(
        random.choice(
            string.ascii_uppercase + string.digits) 
        for _ in range(10)
    )
) for e in np.random.normal(10, 1, 10000)]

posts = spark.createDataFrame(p, ['label', 'val'])

# define the sample size
percent_back = 0.05

# use this if you want an (almost) exact number of samples
# sample_count = 200
# percent_back = sample_count / posts.count()

frac = dict(
    (e.label, percent_back) 
    for e 
    in posts.select('label').distinct().collect()
)

# use this if you want an (almost) balanced sample
# f = posts.groupby('label').count()
#
# f_min_count can also be specified as an exact number, e.g. f_min_count = 5,
# as long as it is less than the minimum count of posts per user
# calculated across all the users
#
# alternatively, take the minimum post count:
# f_min_count = f.select('count').agg(func.min('count').alias('minVal')).collect()[0].minVal
#
# f = f.withColumn('frac', f_min_count / func.col('count'))
# frac = dict(f.select('label', 'frac').collect())

# sample the data
sampled = posts.sampleBy('label', fractions=frac)

# compare the original counts with sampled
original_total_count = posts.count()
original_counts = posts.groupby('label').count()
original_counts = original_counts \
    .withColumn('count_perc', 
                original_counts['count'] / original_total_count)

sampled_total_count = sampled.count()
sampled_counts = sampled.groupBy('label').count()
sampled_counts = sampled_counts \
    .withColumn('count_perc', 
                sampled_counts['count'] / sampled_total_count)


original_counts.sort('label').show(100)
sampled_counts.sort('label').show(100)

print(sampled_total_count)
print(sampled_total_count / original_total_count)
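
For reference, here is a minimal sketch of the commented-out balanced-sample variant above, assuming `func` is `pyspark.sql.functions` and a target of 5 rows per label; capping the fraction at 1.0 is an extra guard (not in the original) for labels that have fewer rows than the target:

from pyspark.sql import functions as func

# target number of rows per label
f_min_count = 5

# per-label fraction that approximates f_min_count rows, capped at 1.0
f = posts.groupby('label').count()
f = f.withColumn('frac',
                 func.least(func.lit(f_min_count) / func.col('count'), func.lit(1.0)))
frac = dict(f.select('label', 'frac').collect())

# sampleBy draws each row independently, so per-label counts are only approximate
sampled_balanced = posts.sampleBy('label', fractions=frac, seed=17)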
TDrabas
  • Nice example! Would you please clarify how to get exact number of samples? I ran the commented out code above but I still get different sample sizes. – Majid Alfifi Jan 07 '17 at 21:50
  • You will not get an exact number but something close to it. It's just a different way of defining the percentages. In the case above it will evaluate to 200/10000 = 2%. – TDrabas Jan 07 '17 at 21:56
  • I was able to get samples close enough to the sample size I needed by defining fractions as follows: frac = posts.groupby('label').count().withColumn('frac', sample_count/F.col('count')).toPandas().set_index('label')['frac'].to_dict() – Majid Alfifi Jan 07 '17 at 22:28
  • If you could incorporate the above modification in your answer so I can go ahead and accept it. – Majid Alfifi Jan 07 '17 at 22:40
  • How is that going to work? `sample_count = 200` and you divide it by the count for each `label`. For instance, `label = 6` would have ~10 observations. Your function then evaluates to 20 and that is something you cannot pass as `fractions` to the `.sampleBy(...)` method. What is more, what you would get in return would not be a stratified sample i.e. a sample with the same proportions of label values as the original dataset. The `.sampleBy(...)` method, under the hood, runs `n` (where `n` is the number of `val`ues in the `label`) uniform sampling from all the records where `label == val`. – TDrabas Jan 07 '17 at 22:52
  • In your code, you pass the same fraction to all the labels but this will still give me larger sample sizes for labels with larger number of values. To account for that I give smaller ratios for labels that have big number of values and bigger ratios for labels that are underrepresented. I still only pass the ratios dict to the `sampleBy` as usual. – Majid Alfifi Jan 07 '17 at 23:08
  • So, in other words, you do not want a stratified sample but rather a balanced one. Just expect that it will not be equal to the number of samples you set, as it will be bound by the count of the smallest value. Therefore, this is what you want: `f = posts.groupby('label').count(); f_min_count = f.select('count').agg(func.min('count').alias('minVal')).collect(); f = f.withColumn('frac',f_min_count[0].minVal/func.col('count')); frac = dict(f.select('label', 'frac').collect())`. This will return a somewhat balanced sample (the counts will most likely not be equal as we're dealing with proportions). – TDrabas Jan 07 '17 at 23:25
  • Also, if you only want a random 5 posts from each user, change the `f_min_count[0].minVal` to 5. Note that since this is random sampling, you can only specify the proportions passed to the `sampleBy(...)` method. Also, a word of caution: the number of posts requested for each user has to be less than the minimum count of posts across all the users. – TDrabas Jan 07 '17 at 23:29
  • I didn't need to address this because I had at least 50 posts for any user but your solution is more generic. I changed my sample to 20 and running this 10 times I got between 10 and 33 which is not perfect but is good enough for my experiment. – Majid Alfifi Jan 08 '17 at 00:03
  • I have to get 10% of each label in the test df, so I made a dict with all the class labels as keys and 0.1 as their values, like `{0: 0.1, 1: 0.1, 2: 0.1, ..., 20: 0.1}`. I supplied this dict as the fractions argument to the `sampleBy()` function, which gives me a sample dataframe; then I use a unique ID to filter the main df to get the test df. However, after that there is a difference between `test_df.count()` and `len(test_df.collect())`, and every time this difference is arbitrary. Can you help with what is wrong here? @TDrabas – Aditya Jul 09 '18 at 12:02

Using sampleBy will result in an approximate solution. Here is an alternative approach that is a little more hacky than the approach above, but it always results in exactly the same sample size for each group.

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

df.withColumn("row_num", row_number().over(Window.partitionBy($"user_id").orderBy($"something_random")))

If you don't already have a random column, you can use org.apache.spark.sql.functions.rand to create one with a random value per row to guarantee your random sampling.
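
A minimal PySpark sketch of the same idea, assuming the DataFrame from the question is named posts and that exactly 5 posts per user are wanted:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rank each user's rows in random order, then keep the first 5 per user
w = Window.partitionBy('user_id').orderBy(F.rand())

sampled = (posts
           .withColumn('row_num', F.row_number().over(w))
           .where(F.col('row_num') <= 5)
           .drop('row_num'))

Users with fewer than 5 posts simply keep all of their posts.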

BushMinusZero
  • Nice elegant approach! I am choosing this as an answer. (not sure if there are performance issues though) – Majid Alfifi Sep 08 '18 at 19:40
  • You can filter on "row_num" to get the number of samples you want, as follows: df.withColumn("row_num", row_number().over(Window.partitionBy($"user_id").orderBy($"something_random"))).where(col("row_num") <= 5) – Balázs Fehér Sep 19 '18 at 13:53