
I have the following data frame in pyspark:

date        user_country  account_type  num_listens
2022-08-01  UK            premium       32
2022-08-01  DE            free          64
2022-08-01  FR            free          93
2022-08-01  UK            free          51
2022-08-02  UK            premium       26
2022-08-02  FR            free          34
2022-08-02  DE            free          29
2022-08-02  DE            premium       41
2022-08-02  DE            free          12
2022-08-02  FR            premium       31
2022-08-03  FR            free          55
2022-08-03  UK            premium       38
2022-08-03  UK            premium       51
2022-08-03  FR            free          81
2022-08-04  DE            free          6
2022-08-04  UK            premium       97
2022-08-04  FR            free          33
2022-08-04  UK            premium       41
2022-08-04  FR            premium       67
2022-08-04  DE            free          86
2022-08-04  DE            free          25
2022-08-04  FR            free          16
2022-08-04  FR            free          48
2022-08-04  UK            premium       11
2022-08-04  UK            free          24
2022-08-05  DE            free          95
2022-08-05  FR            free          68
2022-08-05  DE            premium       23
2022-08-05  UK            free          79
2022-08-05  UK            free          41
2022-08-05  DE            premium       99
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

columns = ["date", "user_country", "account_type", "num_listens"]
# num_listens is an integer so that sorting and median arithmetic are numeric
data = [("2022-08-01", "UK", "premium", 32),
        ("2022-08-01", "DE", "free", 64),
        ("2022-08-01", "FR", "free", 93),
        ("2022-08-01", "UK", "free", 51),
        ("2022-08-02", "UK", "premium", 26),
        ("2022-08-02", "FR", "free", 34),
        ("2022-08-02", "DE", "free", 29),
        ("2022-08-02", "DE", "premium", 41),
        ("2022-08-02", "DE", "free", 12),
        ("2022-08-02", "FR", "premium", 31),
        ("2022-08-03", "FR", "free", 55),
        ("2022-08-03", "UK", "premium", 38),
        ("2022-08-03", "UK", "premium", 51),
        ("2022-08-03", "FR", "free", 81),
        ("2022-08-04", "DE", "free", 6),
        ("2022-08-04", "UK", "premium", 97),
        ("2022-08-04", "FR", "free", 33),
        ("2022-08-04", "UK", "premium", 41),
        ("2022-08-04", "FR", "premium", 67),
        ("2022-08-04", "DE", "free", 86),
        ("2022-08-04", "DE", "free", 25),
        ("2022-08-04", "FR", "free", 16),
        ("2022-08-04", "FR", "free", 48),
        ("2022-08-04", "UK", "premium", 11),
        ("2022-08-04", "UK", "free", 24),
        ("2022-08-05", "DE", "free", 95),
        ("2022-08-05", "FR", "free", 68),
        ("2022-08-05", "DE", "premium", 23),
        ("2022-08-05", "UK", "free", 79),
        ("2022-08-05", "UK", "free", 41),
        ("2022-08-05", "DE", "premium", 99)
       ]

df = spark.createDataFrame(data, columns)

I'm trying to group this data by user_country and account_type, calculating the median of num_listens for each group. On top of this, I would like to use a sliding time window to restrict the data used for each aggregation. For example, when calculating the median for 2022-08-04, I would only like to use data from the ten dates prior.

The resulting table should look as follows:

snapshot_date  user_country  account_type  median
2022-08-06     UK            premium       38
2022-08-06     DE            free          29
2022-08-06     FR            free          52
2022-08-06     UK            free          46
2022-08-06     DE            premium       41
2022-08-06     FR            premium       49
2022-08-05     UK            premium       38
2022-08-05     DE            free          27
2022-08-05     FR            free          48
2022-08-05     UK            free          38
2022-08-05     DE            premium       41
2022-08-05     FR            premium       49
2022-08-04     UK            premium       35
2022-08-04     DE            free          29
2022-08-04     FR            free          68
2022-08-04     UK            free          51
2022-08-04     DE            premium       41
2022-08-04     FR            premium       31
2022-08-03     UK            premium       29
2022-08-03     DE            free          29
2022-08-03     FR            free          64
2022-08-03     UK            free          51
2022-08-03     DE            premium       41
2022-08-03     FR            premium       31
2022-08-02     UK            premium       32
2022-08-02     DE            free          64
2022-08-02     FR            free          93
2022-08-02     UK            free          51

The value in the first row would be the median number of listens for all UK users with a premium account, using data from the previous 10 days (I only included a small sample of 5 days, so in this specific case the full desired range of 10 days is not available).
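For reference, here is a plain-Python sketch (no Spark) of the semantics the expected table appears to use: the median for a snapshot_date is taken over rows dated in the 10 days strictly before it. That window definition is an assumption inferred from the table, not stated explicitly. Two observations from reproducing it: the table's medians seem to be rounded (2022-08-06 FR free is 51.5 before rounding), and both answers below include the current day in the window instead. The names SAMPLE, parse and snapshot_medians are hypothetical.

```python
from collections import defaultdict
from datetime import date, timedelta
from statistics import median

# The question's sample rows: date, user_country, account_type, num_listens
SAMPLE = """\
2022-08-01 UK premium 32
2022-08-01 DE free 64
2022-08-01 FR free 93
2022-08-01 UK free 51
2022-08-02 UK premium 26
2022-08-02 FR free 34
2022-08-02 DE free 29
2022-08-02 DE premium 41
2022-08-02 DE free 12
2022-08-02 FR premium 31
2022-08-03 FR free 55
2022-08-03 UK premium 38
2022-08-03 UK premium 51
2022-08-03 FR free 81
2022-08-04 DE free 6
2022-08-04 UK premium 97
2022-08-04 FR free 33
2022-08-04 UK premium 41
2022-08-04 FR premium 67
2022-08-04 DE free 86
2022-08-04 DE free 25
2022-08-04 FR free 16
2022-08-04 FR free 48
2022-08-04 UK premium 11
2022-08-04 UK free 24
2022-08-05 DE free 95
2022-08-05 FR free 68
2022-08-05 DE premium 23
2022-08-05 UK free 79
2022-08-05 UK free 41
2022-08-05 DE premium 99"""


def parse(text):
    """Turn whitespace-separated lines into (date, country, account, listens)."""
    rows = []
    for line in text.strip().splitlines():
        d, country, acct, n = line.split()
        rows.append((date.fromisoformat(d), country, acct, int(n)))
    return rows


def snapshot_medians(rows, snapshot, days=10):
    """Median per (country, account_type) over rows dated in
    [snapshot - days, snapshot - 1]; groups with no rows are left out."""
    start = snapshot - timedelta(days=days)
    groups = defaultdict(list)
    for d, country, acct, n in rows:
        if start <= d < snapshot:
            groups[(country, acct)].append(n)
    return {key: median(values) for key, values in groups.items()}


rows = parse(SAMPLE)
```

For example, snapshot_medians(rows, date(2022, 8, 4))[("UK", "premium")] evaluates to 35, matching the table.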

Any help on how this can be achieved in pyspark would be much appreciated. I've been fiddling around with combining a group by with a window function but have been unable to get the desired result.

Manuel

2 Answers

Because some user_country/account_type combinations have no records on certain dates in your dataframe (e.g. 2022-08-03, DE, free), but you still need a median for those dates, I will first create a reference table that stores every combination for every date in its range:

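from pyspark.sql import functions as func

# df is the question's DataFrame, e.g. spark.createDataFrame(data, columns).
# For each user_country/account_type, build an array of every date between
# the group's first and last date...
ref_tbl = df\
    .groupBy('user_country', 'account_type')\
    .agg(func.sequence(func.min(func.to_date('date')), func.max(func.to_date('date'))).alias('date_lst'))
# ...then explode it into one placeholder row per date with num_listens = 0
ref_tbl = ref_tbl\
    .select(
        func.explode('date_lst').alias('date'),
        'user_country', 'account_type',
        func.lit(0).alias('num_listens')
    )\
    .withColumn('date', func.date_format('date', 'yyyy-MM-dd'))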

ref_tbl.show(20, False)
+----------+------------+------------+-----------+
|date      |user_country|account_type|num_listens|
+----------+------------+------------+-----------+
|2022-08-01|UK          |premium     |0          |
|2022-08-02|UK          |premium     |0          |
|2022-08-03|UK          |premium     |0          |
|2022-08-04|UK          |premium     |0          |
|2022-08-01|DE          |free        |0          |
|2022-08-02|DE          |free        |0          |
|2022-08-03|DE          |free        |0          |
|2022-08-04|DE          |free        |0          |
|2022-08-05|DE          |free        |0          |
|2022-08-01|FR          |free        |0          |
|2022-08-02|FR          |free        |0          |
|2022-08-03|FR          |free        |0          |
|2022-08-04|FR          |free        |0          |
|2022-08-05|FR          |free        |0          |
|2022-08-01|UK          |free        |0          |
|2022-08-02|UK          |free        |0          |
|2022-08-03|UK          |free        |0          |
|2022-08-04|UK          |free        |0          |
|2022-08-05|UK          |free        |0          |
|2022-08-02|DE          |premium     |0          |
+----------+------------+------------+-----------+
only showing top 20 rows
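The same reference-table idea in plain Python, as a sketch: for each (user_country, account_type) pair, list every calendar date from its first to its last observed date, which is what sequence(min, max) followed by explode produces. The helper name date_grid is hypothetical, and only a few of the question's rows are used.

```python
from datetime import date, timedelta


def date_grid(rows):
    """Map each (country, account_type) to every date from its first to its
    last observed date, inclusive (like sequence(min, max) + explode)."""
    bounds = {}
    for d, country, acct in rows:
        key = (country, acct)
        lo, hi = bounds.get(key, (d, d))
        bounds[key] = (min(lo, d), max(hi, d))
    return {
        key: [lo + timedelta(days=i) for i in range((hi - lo).days + 1)]
        for key, (lo, hi) in bounds.items()
    }


# A few (date, user_country, account_type) rows from the question
rows = [
    (date(2022, 8, 2), "DE", "premium"),
    (date(2022, 8, 5), "DE", "premium"),
    (date(2022, 8, 5), "DE", "premium"),
    (date(2022, 8, 1), "UK", "premium"),
    (date(2022, 8, 3), "UK", "premium"),
    (date(2022, 8, 4), "UK", "premium"),
]
grid = date_grid(rows)
```

The DE/premium grid fills in 2022-08-03 and 2022-08-04 even though that group has no real rows on those dates, which is exactly what the placeholder rows are for.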

Then we can union this reference table back to the main dataframe:

df2 = df\
    .unionByName(ref_tbl)\
    .orderBy(['user_country', 'account_type', 'date'])
df2.show(20, False)
+----------+------------+------------+-----------+
|date      |user_country|account_type|num_listens|
+----------+------------+------------+-----------+
|2022-08-01|DE          |free        |64         |
|2022-08-01|DE          |free        |0          |
|2022-08-02|DE          |free        |0          |
|2022-08-02|DE          |free        |12         |
|2022-08-02|DE          |free        |29         |
|2022-08-03|DE          |free        |0          |
|2022-08-04|DE          |free        |86         |
|2022-08-04|DE          |free        |0          |
|2022-08-04|DE          |free        |25         |
|2022-08-04|DE          |free        |6          |
|2022-08-05|DE          |free        |0          |
|2022-08-05|DE          |free        |95         |
|2022-08-02|DE          |premium     |0          |
|2022-08-02|DE          |premium     |41         |
|2022-08-03|DE          |premium     |0          |
|2022-08-04|DE          |premium     |0          |
|2022-08-05|DE          |premium     |0          |
|2022-08-05|DE          |premium     |23         |
|2022-08-05|DE          |premium     |99         |
|2022-08-01|FR          |free        |93         |
+----------+------------+------------+-----------+
only showing top 20 rows

Your question asks to collect the previous 10 days of records and calculate the median; this is equivalent to taking the current day plus the previous 9 days. You can use a window function with rangeBetween to achieve this:

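from pyspark.sql import Window

# Order each partition by day number (days since 1970-01-01), so that
# rangeBetween(-9, 0) covers the current day and the 9 days before it
window_func = Window\
    .partitionBy('user_country', 'account_type')\
    .orderBy(func.expr("unix_date(to_date(date))"))\
    .rangeBetween(-9, 0)

# Turn the 0 placeholders into nulls, which collect_list skips (this would also
# drop genuine rows with 0 listens); distinct() then keeps one row per
# date/user_country/account_type, since rows with the same date share a frame
df3 = df2\
    .select(
        'date', 'user_country', 'account_type',
        func.collect_list(func.when(func.col('num_listens')>0, func.col('num_listens')).otherwise(func.lit(None))).over(window_func).alias('value_lst')
    ).distinct()\
    .orderBy(['date', 'user_country', 'account_type'], ascending=[0, 0, 0])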

df3.show(10, False)
+----------+------------+------------+--------------------------------+
|date      |user_country|account_type|value_lst                       |
+----------+------------+------------+--------------------------------+
|2022-08-05|UK          |free        |[51, 24, 79, 41]                |
|2022-08-05|FR          |free        |[93, 34, 55, 81, 33, 16, 48, 68]|
|2022-08-05|DE          |premium     |[41, 23, 99]                    |
|2022-08-05|DE          |free        |[64, 29, 12, 6, 86, 25, 95]     |
|2022-08-04|UK          |premium     |[32, 26, 38, 51, 97, 41, 11]    |
|2022-08-04|UK          |free        |[51, 24]                        |
|2022-08-04|FR          |premium     |[31, 67]                        |
|2022-08-04|FR          |free        |[93, 34, 55, 81, 33, 16, 48]    |
|2022-08-04|DE          |premium     |[41]                            |
|2022-08-04|DE          |free        |[64, 29, 12, 6, 86, 25]         |
+----------+------------+------------+--------------------------------+
only showing top 10 rows
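To illustrate what rangeBetween(-9, 0) selects when the window is ordered by unix_date (days since 1970-01-01), here is a hypothetical plain-Python equivalent for one partition, using the DE rows plus the 0 placeholders from the union. Spark does not guarantee the order of the collected list, so the checks compare sorted values.

```python
from datetime import date

EPOCH = date(1970, 1, 1)


def unix_date(d):
    # Days since 1970-01-01, the same quantity Spark SQL's unix_date() returns
    return (d - EPOCH).days


def range_window_values(rows, key, current, lower=-9, upper=0):
    """Collect num_listens for one (country, account_type) partition whose day
    number lies in [current + lower, current + upper]. Values <= 0 (the
    placeholders) are skipped, like the when(... > 0) before collect_list."""
    cur = unix_date(current)
    return [
        n
        for d, country, acct, n in rows
        if (country, acct) == key
        and cur + lower <= unix_date(d) <= cur + upper
        and n > 0
    ]


rows = [
    (date(2022, 8, 1), "DE", "free", 64),
    (date(2022, 8, 1), "DE", "free", 0),   # reference-table placeholder
    (date(2022, 8, 2), "DE", "free", 29),
    (date(2022, 8, 2), "DE", "free", 12),
    (date(2022, 8, 3), "DE", "free", 0),   # placeholder: no real rows that day
    (date(2022, 8, 4), "DE", "free", 6),
    (date(2022, 8, 4), "DE", "free", 86),
    (date(2022, 8, 4), "DE", "free", 25),
    (date(2022, 8, 5), "DE", "free", 95),
    (date(2022, 8, 2), "DE", "premium", 41),
]
```

With current set to 2022-08-11, the frame starts at 2022-08-02, so the 2022-08-01 value drops out; that is the sliding part of the window.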

Finally you can calculate the median:

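import statistics

# func.udf defaults to a StringType return value, which is why the medians below
# print as text such as 46.0 and 41; pass DoubleType() as the return type if
# you need a numeric column
df4 = df3\
    .withColumn('median', func.udf(lambda lst: statistics.median([int(value) for value in lst]))(func.col('value_lst')))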

df4.show(10, False)
+----------+------------+------------+--------------------------------+------+
|date      |user_country|account_type|value_lst                       |median|
+----------+------------+------------+--------------------------------+------+
|2022-08-05|UK          |free        |[51, 24, 79, 41]                |46.0  |
|2022-08-05|FR          |free        |[93, 34, 55, 81, 33, 16, 48, 68]|51.5  |
|2022-08-05|DE          |premium     |[41, 23, 99]                    |41    |
|2022-08-05|DE          |free        |[64, 29, 12, 6, 86, 25, 95]     |29    |
|2022-08-04|UK          |premium     |[32, 26, 38, 51, 97, 41, 11]    |38    |
|2022-08-04|UK          |free        |[51, 24]                        |37.5  |
|2022-08-04|FR          |premium     |[31, 67]                        |49.0  |
|2022-08-04|FR          |free        |[93, 34, 55, 81, 33, 16, 48]    |48    |
|2022-08-04|DE          |premium     |[41]                            |41    |
|2022-08-04|DE          |free        |[64, 29, 12, 6, 86, 25]         |27.0  |
+----------+------------+------------+--------------------------------+------+
only showing top 10 rows

The reason I don't use percentile_approx or approxQuantile from the Spark API here is that they are approximate and return an actual element of the data rather than interpolating (for example, for the list [1, 2], percentile_approx gives 1 as the median rather than 1.5). Therefore I use a UDF with an external library here (you could also write your own logic), but please remember that performance might be poor if the lists are large. Spark SQL's percentile function, by contrast, is documented as exact, so it may also be worth testing.
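A small plain-Python illustration of the [1, 2] example. statistics.median_low is used only as a stand-in for an estimator that must return an element of the data (it matches percentile_approx on this list, not in general). The hypothetical safe_median also guards against an empty or null list, on which statistics.median would raise StatisticsError inside the UDF, and returns a float so it can be registered with an explicit DoubleType.

```python
import statistics

values = [1, 2]

# Exact median: averages the two middle values of an even-length list
exact = statistics.median(values)

# Stand-in for an estimator that must return an element of the data;
# on this tiny list it picks the lower middle value, as percentile_approx does
lower = statistics.median_low(values)


def safe_median(lst):
    """Exact median of a list of numbers or numeric strings, as a float.

    Returns None for a null or empty list instead of letting
    statistics.median raise StatisticsError."""
    if not lst:
        return None
    return float(statistics.median([int(value) for value in lst]))


# Hypothetical registration with an explicit numeric return type:
#   from pyspark.sql.types import DoubleType
#   median_udf = func.udf(safe_median, DoubleType())
#   df4 = df3.withColumn('median', median_udf('value_lst'))
```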

Jonathan Lam

You can collect the values in an array, and then apply the median logic on that.

For simplicity, I'll calculate the median over a window of 4 rows (the current row and the 3 before it) using your sample data. This assumes you don't need the dates to be contiguous, i.e. the previous 3 rows may or may not fall on consecutive dates ([2022-01-01, 2022-01-03, 2022-01-04, 2022-01-04] is also acceptable).

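from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# data_sdf is the question's DataFrame, e.g. spark.createDataFrame(data, columns).
# Collect the current row and the 3 rows before it (ordered by date) within each
# user_country/account_type partition, then sort the array so the median can be
# read off by index
data_sdf. \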
    withColumn('num_listens_arr', 
               func.array_sort(func.collect_list('num_listens').
                               over(wd.partitionBy('user_country', 'account_type').orderBy('date').rowsBetween(-3, 0))
                               )
               ). \
    withColumn('median', 
               func.when(func.size('num_listens_arr') % 2 == 0, 
                         func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                         ).
               otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
               ). \
    show(data_sdf.count())

# +----------+------------+------------+-----------+----------------+------+
# |      date|user_country|account_type|num_listens| num_listens_arr|median|
# +----------+------------+------------+-----------+----------------+------+
# |2022-08-01|          UK|        free|         51|            [51]|  51.0|
# |2022-08-04|          UK|        free|         24|        [24, 51]|  37.5|
# |2022-08-05|          UK|        free|         79|    [24, 51, 79]|  51.0|
# |2022-08-05|          UK|        free|         41|[24, 41, 51, 79]|  46.0|
# |2022-08-01|          UK|     premium|         32|            [32]|  32.0|
# |2022-08-02|          UK|     premium|         26|        [26, 32]|  29.0|
# |2022-08-03|          UK|     premium|         38|    [26, 32, 38]|  32.0|
# |2022-08-03|          UK|     premium|         51|[26, 32, 38, 51]|  35.0|
# |2022-08-04|          UK|     premium|         97|[26, 38, 51, 97]|  44.5|
# |2022-08-04|          UK|     premium|         41|[38, 41, 51, 97]|  46.0|
# |2022-08-04|          UK|     premium|         11|[11, 41, 51, 97]|  46.0|
# |2022-08-02|          DE|     premium|         41|            [41]|  41.0|
# |2022-08-05|          DE|     premium|         23|        [23, 41]|  32.0|
# |2022-08-05|          DE|     premium|         99|    [23, 41, 99]|  41.0|
# |2022-08-01|          DE|        free|         64|            [64]|  64.0|
# |2022-08-02|          DE|        free|         29|        [29, 64]|  46.5|
# |2022-08-02|          DE|        free|         12|    [12, 29, 64]|  29.0|
# |2022-08-04|          DE|        free|          6| [6, 12, 29, 64]|  20.5|
# |2022-08-04|          DE|        free|         86| [6, 12, 29, 86]|  20.5|
# |2022-08-04|          DE|        free|         25| [6, 12, 25, 86]|  18.5|
# |2022-08-05|          DE|        free|         95| [6, 25, 86, 95]|  55.5|
# |2022-08-01|          FR|        free|         93|            [93]|  93.0|
# |2022-08-02|          FR|        free|         34|        [34, 93]|  63.5|
# |2022-08-03|          FR|        free|         55|    [34, 55, 93]|  55.0|
# |2022-08-03|          FR|        free|         81|[34, 55, 81, 93]|  68.0|
# |2022-08-04|          FR|        free|         33|[33, 34, 55, 81]|  44.5|
# |2022-08-04|          FR|        free|         16|[16, 33, 55, 81]|  44.0|
# |2022-08-04|          FR|        free|         48|[16, 33, 48, 81]|  40.5|
# |2022-08-05|          FR|        free|         68|[16, 33, 48, 68]|  40.5|
# |2022-08-02|          FR|     premium|         31|            [31]|  31.0|
# |2022-08-04|          FR|     premium|         67|        [31, 67]|  49.0|
# +----------+------------+------------+-----------+----------------+------+
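As a sketch of what rowsBetween(-3, 0) does within one partition, here is a plain-Python version applied to the UK/premium values in the question's row order (rolling_row_medians is a hypothetical name). One caveat: rows that share a date, such as the three 2022-08-04 UK/premium rows, have no guaranteed order in Spark, so their rolling medians can differ between runs unless a tiebreaker is added to orderBy; the Python version simply keeps the input order.

```python
from statistics import median


def rolling_row_medians(values, preceding=3):
    """For each position i, the median of values[i - preceding .. i]
    (inclusive), mirroring rowsBetween(-preceding, 0) in one partition."""
    out = []
    for i in range(len(values)):
        window = values[max(0, i - preceding): i + 1]
        out.append(float(median(window)))
    return out


# UK/premium num_listens in the question's row order (already date-ordered)
uk_premium = [32, 26, 38, 51, 97, 41, 11]
medians = rolling_row_medians(uk_premium)
```

The result lines up with the median column for the UK/premium rows in the output above.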

If you do want the window to cover a contiguous range of dates, you can use rangeBetween() on a numeric version of the date:

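# dt_long is each date's midnight as epoch seconds (in the session time zone),
# so going back 3 days means going back 3*24*60*60 seconds
data_sdf. \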
    withColumn('dt_long', func.col('date').cast('timestamp').cast('long')). \
    withColumn('num_listens_arr', 
               func.array_sort(func.collect_list('num_listens').
                               over(wd.partitionBy('user_country', 'account_type').orderBy('dt_long').rangeBetween(-3*24*60*60, 0))
                               )
               ). \
    withColumn('median', 
               func.when(func.size('num_listens_arr') % 2 == 0, 
                         func.expr('(num_listens_arr[int(size(num_listens_arr) / 2)-1] + num_listens_arr[int(size(num_listens_arr) / 2)]) / 2').cast('double')
                         ).
               otherwise(func.expr('num_listens_arr[int(size(num_listens_arr) / 2)]').cast('double'))
               ). \
    show(data_sdf.count(), truncate=False)

# +----------+------------+------------+-----------+----------+----------------------------+------+
# |date      |user_country|account_type|num_listens|dt_long   |num_listens_arr             |median|
# +----------+------------+------------+-----------+----------+----------------------------+------+
# |2022-08-01|UK          |free        |51         |1659312000|[51]                        |51.0  |
# |2022-08-04|UK          |free        |24         |1659571200|[24, 51]                    |37.5  |
# |2022-08-05|UK          |free        |79         |1659657600|[24, 41, 79]                |41.0  |
# |2022-08-05|UK          |free        |41         |1659657600|[24, 41, 79]                |41.0  |
# |2022-08-01|UK          |premium     |32         |1659312000|[32]                        |32.0  |
# |2022-08-02|UK          |premium     |26         |1659398400|[26, 32]                    |29.0  |
# |2022-08-03|UK          |premium     |38         |1659484800|[26, 32, 38, 51]            |35.0  |
# |2022-08-03|UK          |premium     |51         |1659484800|[26, 32, 38, 51]            |35.0  |
# |2022-08-04|UK          |premium     |97         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
# |2022-08-04|UK          |premium     |41         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
# |2022-08-04|UK          |premium     |11         |1659571200|[11, 26, 32, 38, 41, 51, 97]|38.0  |
# |2022-08-02|DE          |premium     |41         |1659398400|[41]                        |41.0  |
# |2022-08-05|DE          |premium     |23         |1659657600|[23, 41, 99]                |41.0  |
# |2022-08-05|DE          |premium     |99         |1659657600|[23, 41, 99]                |41.0  |
# |2022-08-01|DE          |free        |64         |1659312000|[64]                        |64.0  |
# |2022-08-02|DE          |free        |29         |1659398400|[12, 29, 64]                |29.0  |
# |2022-08-02|DE          |free        |12         |1659398400|[12, 29, 64]                |29.0  |
# |2022-08-04|DE          |free        |6          |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
# |2022-08-04|DE          |free        |86         |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
# |2022-08-04|DE          |free        |25         |1659571200|[6, 12, 25, 29, 64, 86]     |27.0  |
# |2022-08-05|DE          |free        |95         |1659657600|[6, 12, 25, 29, 86, 95]     |27.0  |
# |2022-08-01|FR          |free        |93         |1659312000|[93]                        |93.0  |
# |2022-08-02|FR          |free        |34         |1659398400|[34, 93]                    |63.5  |
# |2022-08-03|FR          |free        |55         |1659484800|[34, 55, 81, 93]            |68.0  |
# |2022-08-03|FR          |free        |81         |1659484800|[34, 55, 81, 93]            |68.0  |
# |2022-08-04|FR          |free        |33         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
# |2022-08-04|FR          |free        |16         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
# |2022-08-04|FR          |free        |48         |1659571200|[16, 33, 34, 48, 55, 81, 93]|48.0  |
# |2022-08-05|FR          |free        |68         |1659657600|[16, 33, 34, 48, 55, 68, 81]|48.0  |
# |2022-08-02|FR          |premium     |31         |1659398400|[31]                        |31.0  |
# |2022-08-04|FR          |premium     |67         |1659571200|[31, 67]                    |49.0  |
# +----------+------------+------------+-----------+----------+----------------------------+------+
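A plain-Python check of the seconds arithmetic behind rangeBetween(-3*24*60*60, 0), assuming a UTC session time zone (the dt_long values above are midnights in UTC). Under that assumption the frame covers the current date plus the 3 calendar days before it. In a time zone with daylight-saving changes, a day is not always 86,400 seconds, so the frame boundaries could shift by an hour. The names dt_long and in_frame are hypothetical.

```python
from datetime import datetime, timezone

SECONDS_PER_DAY = 24 * 60 * 60


def dt_long(iso_date):
    """Epoch seconds at midnight UTC of an ISO 'YYYY-MM-DD' date, mirroring
    cast('timestamp').cast('long') under a UTC session time zone."""
    midnight = datetime.fromisoformat(iso_date).replace(tzinfo=timezone.utc)
    return int(midnight.timestamp())


def in_frame(current, other, days_back=3):
    """True if `other` lies in rangeBetween(-days_back * SECONDS_PER_DAY, 0)
    relative to `current` (both in epoch seconds)."""
    return current - days_back * SECONDS_PER_DAY <= other <= current


# UK/free rows from the question: (date, num_listens)
uk_free = [("2022-08-01", 51), ("2022-08-04", 24), ("2022-08-05", 79), ("2022-08-05", 41)]
current = dt_long("2022-08-05")
frame = [n for d, n in uk_free if in_frame(current, dt_long(d))]
```

The 2022-08-01 value falls just outside the frame for 2022-08-05, which is why that row's array above is [24, 41, 79].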

For the median calculation: if the array of values has an even number of elements, the median is the average of the middle 2 elements. So:

  • sort the array of values
  • check the size of the array (number of elements)
    • if the size is even, average the elements at (0-based) indices (size/2)-1 and size/2, e.g. if size is 6, arr[2] and arr[3]
    • otherwise, take the middle element arr[int(size/2)] as the median, e.g. arr[3] if size is 7
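The same index arithmetic in plain Python, as a sketch for checking the Spark expression by hand (array_median is a hypothetical name; indices are 0-based, as with arr[i] in Spark SQL).

```python
def array_median(values):
    """Median using the same steps as the Spark SQL expression: sort, then
    for an even size average arr[size/2 - 1] and arr[size/2], otherwise take
    arr[size/2] (integer division, 0-based indices)."""
    arr = sorted(values)
    n = len(arr)
    if n == 0:
        return None  # nothing to take a median of
    mid = n // 2
    if n % 2 == 0:
        return (arr[mid - 1] + arr[mid]) / 2
    return float(arr[mid])
```

For example, the DE/free array [6, 12, 25, 29, 64, 86] has size 6, so the median is (arr[2] + arr[3]) / 2 = (25 + 29) / 2 = 27.0.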
samkart