I have a Spark DataFrame that looks something like the one below.
date | ID | window_size | qty |
---|---|---|---|
01/01/2020 | 1 | 2 | 1 |
02/01/2020 | 1 | 2 | 2 |
03/01/2020 | 1 | 2 | 3 |
04/01/2020 | 1 | 2 | 4 |
01/01/2020 | 2 | 3 | 1 |
02/01/2020 | 2 | 3 | 2 |
03/01/2020 | 2 | 3 | 3 |
04/01/2020 | 2 | 3 | 4 |
I'm trying to apply a rolling window of size `window_size` to each `ID` in the dataframe and get the rolling sum. Basically I'm calculating the rolling sum (`df.groupby(...).rolling(window=n).sum()` in pandas), where the window size `n` can change per group.
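To make that concrete, here is a minimal pandas sketch of the behaviour I'm after (toy data mirroring the table above; obviously this does not scale to the real frame):

```python
import pandas as pd

# Toy version of the frame above; rows are assumed already sorted by date
# within each ID, and window_size is constant per ID.
df = pd.DataFrame({
    "ID": [1, 1, 1, 1, 2, 2, 2, 2],
    "window_size": [2, 2, 2, 2, 3, 3, 3, 3],
    "qty": [1, 2, 3, 4, 1, 2, 3, 4],
})

# Per group, roll with that group's own window size.
df["rolling_sum"] = (
    df.groupby("ID", group_keys=False)
      .apply(lambda g: g["qty"].rolling(window=g["window_size"].iat[0]).sum())
)
print(df)  # rolling_sum: NaN, 3, 5, 7, NaN, NaN, 6, 9
```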
Expected output
date | ID | window_size | qty | rolling_sum |
---|---|---|---|---|
01/01/2020 | 1 | 2 | 1 | null |
02/01/2020 | 1 | 2 | 2 | 3 |
03/01/2020 | 1 | 2 | 3 | 5 |
04/01/2020 | 1 | 2 | 4 | 7 |
01/01/2020 | 2 | 3 | 1 | null |
02/01/2020 | 2 | 3 | 2 | null |
03/01/2020 | 2 | 3 | 3 | 6 |
04/01/2020 | 2 | 3 | 4 | 9 |
I'm struggling to find a solution that works and is fast enough on a large dataframe (~350M rows).
What I have tried
I tried the solution in the below thread:
The idea is to first use `sf.collect_list` and then slice the `ArrayType` column correctly.
```python
import pyspark.sql.types as st
import pyspark.sql.functions as sf
from pyspark.sql.window import Window

window = Window.partitionBy('ID').orderBy('date')
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum",
                sf.when(sf.col('count') < sf.col('window_size'), None)
                  # this slice call is the line that raises the TypeError below
                  .otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
```
However, this yields the following error:

```
TypeError: Column is not iterable
```
I have also tried using `sf.expr`, like below:
```python
window = Window.partitionBy('ID').orderBy('date')
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum",
                sf.when(sf.col('count') < sf.col('window_size'), None)
                  .otherwise(sf.expr("slice('qty_list', 'count', 'window_size')")))
).show()
```
Which yields:

```
data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;
```
I tried manually casting the `qty_list` column to `ArrayType(IntegerType())`, with the same result.
I also tried using a UDF, but that fails with several out-of-memory errors after about 1.5 hours.
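For reference, the UDF attempts were roughly along these lines (a representative sketch, not the exact code; assumes Spark 3.x with `applyInPandas`):

```python
# Illustrative sketch of the grouped-map UDF approach; presumably the OOM comes
# from the fact that every row of an ID must fit in one executor's memory at once.
def rolling_sum_per_group(pdf):
    n = int(pdf["window_size"].iat[0])  # window size is constant per ID
    pdf = pdf.sort_values("date")       # assumes the date column sorts correctly
    pdf["rolling_sum"] = pdf["qty"].rolling(window=n).sum()
    return pdf

schema = "date string, ID long, window_size long, qty long, rolling_sum double"
output = sdf.groupby("ID").applyInPandas(rolling_sum_per_group, schema=schema)
```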
Questions

1. Reading the Spark documentation suggests to me that I should be able to pass columns to `sf.slice()`, so am I doing something wrong? Where is the `TypeError` coming from?
2. Is there a better way to achieve what I want without using `sf.collect_list()` and/or `sf.slice()`?
3. If all else fails, what would be the optimal way to do this using a UDF? I attempted different versions of the same UDF and tried to make sure the UDF is the last operation Spark has to perform, but all failed.