Suppose we have a very large Spark DataFrame in which we want to create a lag column:
from pyspark.sql import functions as func
from pyspark.sql.window import Window

lag_df = df.withColumn('lag', func.lag(df['id'], 1)
                       .over(Window.partitionBy().orderBy('id')))
+---+----+
| id| lag|
+---+----+
| 1|null|
| 2| 1|
| 3| 2|
| 4| 3|
| 5| 4|
...
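
Inspecting the physical plan (as far as I can tell) confirms where the bottleneck comes from:

lag_df.explain()
# With an empty partitionBy(), the plan shows an 'Exchange SinglePartition'
# step feeding the Window operator, i.e. every row is shuffled into one task
# before the lag is computed. Spark also logs a
# "No Partition Defined for Window operation!" warning at runtime.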
I found that the above ends up running on a single executor. That is fine for a small DataFrame, but it does not scale at all. There is no natural key we can pass to partitionBy, so is there a different way to make this task scale?
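
Purely as a sketch of the kind of workaround I have in mind (reusing the imports from above; BUCKET_SIZE is a made-up tuning knob, and it assumes id is a roughly dense numeric key so adjacent buckets are non-empty): bucket the rows by id, compute the lag within each bucket in parallel, then patch the first row of each bucket with the maximum id of the previous bucket.

# Rough sketch only: BUCKET_SIZE is an assumed tuning parameter, and `id` is
# assumed to be a roughly dense numeric key so adjacent buckets are non-empty.
BUCKET_SIZE = 1000000

bucketed = df.withColumn('bucket', func.floor(df['id'] / BUCKET_SIZE))

# The lag is now computed per bucket, so the work spreads across executors.
w = Window.partitionBy('bucket').orderBy('id')
lag_df = bucketed.withColumn('lag', func.lag(bucketed['id'], 1).over(w))

# The first row of each bucket still gets a null lag; repair it by joining in
# the maximum id of the previous bucket (a tiny aggregate, cheap to shuffle).
prev_max = (bucketed.groupBy('bucket')
            .agg(func.max('id').alias('prev_lag'))
            .withColumn('bucket', func.col('bucket') + 1))

lag_df = (lag_df.join(prev_max, on='bucket', how='left')
                .withColumn('lag', func.coalesce(func.col('lag'), func.col('prev_lag')))
                .drop('bucket', 'prev_lag'))

The per-bucket aggregate is small, so the extra join should be much cheaper than shuffling everything into one partition. Is something along these lines reasonable, or is there a cleaner approach?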