Use window functions:
from pyspark.sql.functions import col, sum, monotonically_increasing_id
from pyspark.sql.window import Window
df = spark.createDataFrame(
    [(1, 2), (1, 3), (1, 1), (1, 9), (2, 1), (2, 6), (2, 8), (2, 1)],
    ("id", "val")
)
You'll need a Window partitioned by id, ordered by a row identifier, with a frame covering the next one to two rows (rowsBetween(1, 2)):
w = (Window.partitionBy("id")
     .orderBy("_id")
     .rowsBetween(1, 2))
Then add the _id column with monotonically_increasing_id and apply the window:
(df
    .withColumn("_id", monotonically_increasing_id())
    .withColumn("sum_val", sum("val").over(w))
    .na.fill(0)
    .show())
# +---+---+-----------+-------+
# | id|val| _id|sum_val|
# +---+---+-----------+-------+
# | 1| 2| 0| 4|
# | 1| 3| 1| 10|
# | 1| 1| 8589934592| 9|
# | 1| 9| 8589934593| 0|
# | 2| 1|17179869184| 14|
# | 2| 6|17179869185| 9|
# | 2| 8|25769803776| 1|
# | 2| 1|25769803777| 0|
# +---+---+-----------+-------+
Please be aware that using monotonically_increasing_id like this is not good practice. In production you should always have the ordering information embedded in the data itself, and never depend on the internal order of the DataFrame.
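If your data already carries such a column, order the window by it directly. A minimal sketch, assuming a hypothetical df_with_ts DataFrame whose ts column provides the ordering (column and DataFrame names are placeholders, not from the example above):

from pyspark.sql.functions import sum
from pyspark.sql.window import Window

# Frame still spans the next one to two rows, but the ordering
# comes from a column that exists in the data itself.
w = (Window.partitionBy("id")
     .orderBy("ts")
     .rowsBetween(1, 2))

(df_with_ts  # hypothetical DataFrame with columns id, ts, val
    .withColumn("sum_val", sum("val").over(w))
    .na.fill(0)
    .show())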