Starting from the following spark data frame:
from io import StringIO
import pandas as pd
from pyspark.sql.functions import col
pd_df = pd.read_csv(StringIO("""device_id,read_date,id,count
device_A,2017-08-05,4041,3
device_A,2017-08-06,4041,3
device_A,2017-08-07,4041,4
device_A,2017-08-08,4041,3
device_A,2017-08-09,4041,3
device_A,2017-08-10,4041,1
device_A,2017-08-10,4045,2
device_A,2017-08-11,4045,3
device_A,2017-08-12,4045,3
device_A,2017-08-13,4045,3"""),infer_datetime_format=True, parse_dates=['read_date'])
# `spark` is the active SparkSession, as provided by the pyspark shell
df = spark.createDataFrame(pd_df).withColumn('read_date', col('read_date').cast('date'))
df.show()
Output:
+---------+----------+----+-----+
|device_id| read_date|  id|count|
+---------+----------+----+-----+
| device_A|2017-08-05|4041|    3|
| device_A|2017-08-06|4041|    3|
| device_A|2017-08-07|4041|    4|
| device_A|2017-08-08|4041|    3|
| device_A|2017-08-09|4041|    3|
| device_A|2017-08-10|4041|    1|
| device_A|2017-08-10|4045|    2|
| device_A|2017-08-11|4045|    3|
| device_A|2017-08-12|4045|    3|
| device_A|2017-08-13|4045|    3|
+---------+----------+----+-----+
I would like to find the most frequent id for each (device_id, read_date) combination over a 3-day rolling window, i.e. the current day and the two preceding days. For each group of rows selected by the window, I need to sum the counts per id and return the id with the largest total.
Expected Output:
+---------+----------+----+
|device_id| read_date|  id|
+---------+----------+----+
| device_A|2017-08-05|4041|
| device_A|2017-08-06|4041|
| device_A|2017-08-07|4041|
| device_A|2017-08-08|4041|
| device_A|2017-08-09|4041|
| device_A|2017-08-10|4041|
| device_A|2017-08-11|4045|
| device_A|2017-08-12|4045|
| device_A|2017-08-13|4045|
+---------+----------+----+
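The closest I have gotten with plain window functions is to compute a 3-day rolling sum of count per (device_id, id) and then keep, for each date, the id with the largest rolling sum. The rolling_count and rn column names below are mine:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

DAY = 86400  # seconds per day, for the range-based frame

# Rolling 3-day sum of `count` per (device_id, id): the current day plus
# the two preceding days, with the order key in seconds since the epoch.
w_roll = (Window.partitionBy('device_id', 'id')
          .orderBy(F.col('read_date').cast('timestamp').cast('long'))
          .rangeBetween(-2 * DAY, 0))

# Within each (device_id, read_date), rank ids by their rolling sum;
# `id` is only a deterministic tie-breaker.
w_rank = (Window.partitionBy('device_id', 'read_date')
          .orderBy(F.col('rolling_count').desc(), 'id'))

result = (df
          .withColumn('rolling_count', F.sum('count').over(w_roll))
          .withColumn('rn', F.row_number().over(w_rank))
          .where(F.col('rn') == 1)
          .select('device_id', 'read_date', 'id')
          .orderBy('device_id', 'read_date'))
result.show()

This reproduces the expected output on the sample above, but it is not quite the aggregation I described: an id with no row on a given date never competes on that date, even if its 3-day total would be the largest.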
I am starting to think a fully general version is only possible with a custom aggregation function. Since Spark 2.3 is not out yet, I would have to write one in Scala or fall back on collect_list with a Python UDF, as sketched below. Am I missing something?
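For concreteness, the collect_list fallback I have in mind looks roughly like this; the pairs column, the top_id helper, and its UDF wrapper are all mine, and it is an ordinary Python UDF rather than a real UDAF:

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import LongType

DAY = 86400  # seconds per day

# Same rolling 3-day frame, but per device only: collect every (id, count)
# pair visible in the window and post-process the list in Python.
w = (Window.partitionBy('device_id')
     .orderBy(F.col('read_date').cast('timestamp').cast('long'))
     .rangeBetween(-2 * DAY, 0))

def top_id(pairs):
    # Sum the counts per id and return the id with the largest total.
    totals = {}
    for id_, cnt in pairs:
        totals[id_] = totals.get(id_, 0) + cnt
    return max(totals.items(), key=lambda kv: kv[1])[0]

top_id_udf = F.udf(top_id, LongType())

result = (df
          .withColumn('pairs', F.collect_list(F.struct('id', 'count')).over(w))
          .select('device_id', 'read_date', top_id_udf('pairs').alias('id'))
          .distinct()  # 2017-08-10 has two input rows that yield the same answer
          .orderBy('device_id', 'read_date'))
result.show()

This does give the expected output, but collecting all (id, count) pairs per window onto a single row feels wasteful compared to what a proper UDAF could do.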