I recently had to implement a similar aggregation. My first attempt used a Pandas UDF over a sliding window, the performance was quite bad, and I managed to improve it with the following approach.
Try to use collect_list to build the sliding-window vectors and then map them with your UDF. Note that this only works if the sliding window fits into a worker's memory (which it usually does).
Here is my test code. The first part is just your code, but turned into a complete reproducible example.
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf
from pyspark.sql.types import FloatType, StructType, StructField, IntegerType, StringType
df = spark.createDataFrame(
    [(1, "2021-04-01", 10, -30),
     (1, "2021-03-01", 10, 20),
     (1, "2021-02-01", 10, -1),
     (1, "2021-01-01", 10, 10),
     (1, "2020-12-01", 10, 5),
     (1, "2021-04-01", 20, -5),
     (1, "2021-03-01", 20, -4),
     (1, "2021-02-01", 20, -3),
     (2, "2021-03-01", 10, 5),
     (2, "2021-02-01", 10, 6),
     ],
    StructType([
        StructField("csecid", StringType(), True),
        StructField("date", StringType(), True),
        StructField("analystid", IntegerType(), True),
        StructField("revisions_improved", IntegerType(), True)
    ]))
### Baseline
@pandas_udf(FloatType(), PandasUDFType.GROUPED_AGG)
def method2(analyst: pd.Series, revisions: pd.Series) -> float:
    df = pd.DataFrame({
        'analyst': analyst,
        'revisions': revisions
    })
    return df.groupby('analyst').last()['revisions'].sum() / df.groupby('analyst').last()['revisions'].abs().sum()
days = lambda x: x*60*60*24
w = Window.partitionBy('csecid').orderBy(F.col('date').cast('timestamp').cast('long')).rangeBetween(-days(100), 0)
# df.withColumn('new_col', method2(F.col('analystid'), F.col('revisions_improved')).over(w))
Proposed alternative:
### Method 3
from typing import List

@udf(FloatType())
def method3(analyst: List[int], revisions: List[int]) -> float:
    df = pd.DataFrame({
        'analyst': analyst,
        'revisions': revisions
    })
    return float(df.groupby('analyst').last()['revisions'].sum() / df.groupby('analyst').last()['revisions'].abs().sum())
(df
 .withColumn('new_col', method2(F.col('analystid'), F.col('revisions_improved')).over(w))
 .withColumn('analyst_win', F.collect_list("analystid").over(w))
 .withColumn('revisions_win', F.collect_list("revisions_improved").over(w))
 .withColumn('method3', method3(F.collect_list("analystid").over(w),
                                F.collect_list("revisions_improved").over(w)))
 .orderBy("csecid", "date", "analystid")
 .show(truncate=False))
Result:
+------+----------+---------+------------------+---------+----------------------------+-----------------------------+---------+
|csecid|date |analystid|revisions_improved|new_col |analyst_win |revisions_win |method3 |
+------+----------+---------+------------------+---------+----------------------------+-----------------------------+---------+
|1 |2020-12-01|10 |5 |1.0 |[10] |[5] |1.0 |
|1 |2021-01-01|10 |10 |1.0 |[10, 10] |[5, 10] |1.0 |
|1 |2021-02-01|10 |-1 |-1.0 |[10, 10, 10, 20] |[5, 10, -1, -3] |-1.0 |
|1 |2021-02-01|20 |-3 |-1.0 |[10, 10, 10, 20] |[5, 10, -1, -3] |-1.0 |
|1 |2021-03-01|10 |20 |0.6666667|[10, 10, 10, 20, 10, 20] |[5, 10, -1, -3, 20, -4] |0.6666667|
|1 |2021-03-01|20 |-4 |0.6666667|[10, 10, 10, 20, 10, 20] |[5, 10, -1, -3, 20, -4] |0.6666667|
|1 |2021-04-01|10 |-30 |-1.0 |[10, 10, 20, 10, 20, 10, 20]|[10, -1, -3, 20, -4, -30, -5]|-1.0 |
|1 |2021-04-01|20 |-5 |-1.0 |[10, 10, 20, 10, 20, 10, 20]|[10, -1, -3, 20, -4, -30, -5]|-1.0 |
|2 |2021-02-01|10 |6 |1.0 |[10] |[6] |1.0 |
|2 |2021-03-01|10 |5 |1.0 |[10, 10] |[6, 5] |1.0 |
+------+----------+---------+------------------+---------+----------------------------+-----------------------------+---------+
analyst_win and revisions_win are only there to show how the sliding windows are built and passed into the UDF. They should be removed in production.
Moving the pandas groupby out of the UDF and letting Spark handle that step would probably improve the performance further. However, I did not touch that part because you mentioned the function is not representative of your actual task.
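If you do want to try letting Spark do it, here is a rough sketch of the idea, based only on the example data above and reusing F, Window and days from the code (untested at scale; the names ref, obs and new_col_no_udf are mine, and it assumes the metric depends only on csecid and date, as it does in the example):

# Sketch: compute the same ratio with a range self-join and native
# aggregations instead of a Python UDF.
ref = (df.select('csecid', 'date').distinct()
         .withColumn('ref_ts', F.col('date').cast('timestamp').cast('long')))

obs = df.select(F.col('csecid').alias('obs_csecid'),
                'analystid',
                'revisions_improved',
                F.col('date').cast('timestamp').cast('long').alias('obs_ts'))

# All observations of the same security in the 100-day window ending at the reference date.
joined = ref.join(obs, (ref.csecid == obs.obs_csecid) &
                       (obs.obs_ts <= ref.ref_ts) &
                       (obs.obs_ts >= ref.ref_ts - days(100)))

# Keep each analyst's most recent revision inside that window...
w_last = Window.partitionBy('csecid', 'date', 'analystid').orderBy(F.col('obs_ts').desc())

# ...and let Spark compute the sum / abs-sum ratio per (csecid, date).
no_udf = (joined
          .withColumn('rn', F.row_number().over(w_last))
          .filter('rn = 1')
          .groupBy('csecid', 'date')
          .agg((F.sum('revisions_improved')
                / F.sum(F.abs(F.col('revisions_improved')))).alias('new_col_no_udf')))

result = df.join(no_udf, ['csecid', 'date'], 'left')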
Check the performance in the Spark UI, especially the time statistics for the tasks that apply the UDF. If those times are high, try increasing the number of partitions with repartition so that each task processes a smaller subset of the data.
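For instance, something like this before applying the window (200 is just a placeholder partition count, tune it for your cluster and data size):

# Placeholder sketch: spread the work over more tasks before the window stage.
df = df.repartition(200, 'csecid')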