I have a pyspark dataframe that looks like (a massively larger version of) the following:
+---+---+----+----+
| id| t|type| val|
+---+---+----+----+
|100| 1| 1| 10|
|100| 2| 0|NULL|
|100| 5| 1| 20|
|100| 8| 0|NULL|
|100| 12| 0|NULL|
|100| 20| 0|NULL|
|100| 22| 1| 30|
|200| 5| 1| 40|
|200| 11| 0|NULL|
|200| 19| 1| 50|
|200| 24| 0|NULL|
|200| 25| 0|NULL|
+---+---+----+----+
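For reference, here is a snippet that builds this sample frame (assuming an active SparkSession; df is the frame used below):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
data = [
    (100, 1, 1, 10), (100, 2, 0, None), (100, 5, 1, 20),
    (100, 8, 0, None), (100, 12, 0, None), (100, 20, 0, None),
    (100, 22, 1, 30),
    (200, 5, 1, 40), (200, 11, 0, None), (200, 19, 1, 50),
    (200, 24, 0, None), (200, 25, 0, None),
]
df = spark.createDataFrame(data, ['id', 't', 'type', 'val'])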
I want to make a new column which, for rows with type 1, uses val, and for rows with type 0, uses val from the most recent type-1 row (ordering by t within each id). The output would look like this:
+---+---+----+----+----+
| id| t|type| val|val2|
+---+---+----+----+----+
|100| 1| 1| 10| 10|
|100| 2| 0|NULL| 10|
|100| 5| 1| 20| 20|
|100| 8| 0|NULL| 20|
|100| 12| 0|NULL| 20|
|100| 20| 0|NULL| 20|
|100| 22| 1| 30| 30|
|200| 5| 1| 40| 40|
|200| 11| 0|NULL| 40|
|200| 19| 1| 50| 50|
|200| 24| 0|NULL| 50|
|200| 25| 0|NULL| 50|
+---+---+----+----+----+
It's fairly straightforward how this could be done by iterating over the rows if we were in a pandas dataframe, but I can't figure out a way to do it using tools from pyspark.
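For reference, the row-by-row logic I have in mind would look something like this in pandas (a rough sketch; fill_group is just an illustrative name):

import pandas as pd

def fill_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # Walk the rows in time order, carrying the most recent type-1 val forward.
    pdf = pdf.sort_values('t')
    last = None
    val2 = []
    for _, row in pdf.iterrows():
        if row['type'] == 1:
            last = row['val']
        val2.append(last)
    return pdf.assign(val2=val2)

What I'd like to do in PySpark is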
from pyspark.sql import Window
import pyspark.sql.functions as sf

w = Window.partitionBy('id').orderBy('t')
df.withColumn(
    'val2',
    sf.when(sf.col('type') == 1, sf.col('val'))
      .otherwise(sf.lag('val').over(w)),
).show()
but this yields
+---+---+----+----+----+
| id| t|type| val|val2|
+---+---+----+----+----+
|100| 1| 1| 10| 10|
|100| 2| 0|NULL| 10|
|100| 5| 1| 20| 20|
|100| 8| 0|NULL| 20|
|100| 12| 0|NULL|NULL|
|100| 20| 0|NULL|NULL|
|100| 22| 1| 30| 30|
|200| 5| 1| 40| 40|
|200| 11| 0|NULL| 40|
|200| 19| 1| 50| 50|
|200| 24| 0|NULL| 50|
|200| 25| 0|NULL|NULL|
+---+---+----+----+----+
I understand why this doesn't work (lag only reaches back one row, so a type-0 row that follows another type-0 row gets NULL), but I'm not sure how to fix it. I think I could wrap a row-iterating function like fill_group above in a groupby('id').applyInPandas(...) call (sketched below), but this would be really slow. Is there a better way to do it?
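Concretely, the fallback I'm imagining is something like this (a sketch; the schema types are guesses, and val/val2 are declared as double because the NULLs make them floats on the pandas side):

result = df.groupby('id').applyInPandas(
    fill_group,
    schema='id long, t long, type long, val double, val2 double',
)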