You can use pyspark.sql.functions.lead() and pyspark.sql.functions.lag(), but first you need a way to order your rows. If you don't already have a column that determines the order, you can create one with pyspark.sql.functions.monotonically_increasing_id() and then use it in conjunction with a Window function.

For example, suppose you had the following DataFrame df:
df.show()
#+---+---+---+---+
#| a| b| c| d|
#+---+---+---+---+
#| 1| 0| 1| 0|
#| 0| 0| 1| 1|
#| 0| 1| 0| 1|
#+---+---+---+---+
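If you want to follow along, here is one way to build this example DataFrame. This is a sketch that assumes an active SparkSession bound to a variable named spark (that name is an assumption, not part of the original example):

df = spark.createDataFrame(
    [(1, 0, 1, 0), (0, 0, 1, 1), (0, 1, 0, 1)],
    ["a", "b", "c", "d"]
)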
You could do:
from pyspark.sql import Window
import pyspark.sql.functions as f

cols = df.columns

# Add an increasing id so the window has a deterministic ordering
df = df.withColumn("id", f.monotonically_increasing_id())
w = Window.orderBy("id")

# For each original column, add the previous and next row's value,
# defaulting to 0 at the first and last rows
df.select(
    "*",
    *(
        [f.lag(f.col(c), default=0).over(w).alias("prev_" + c) for c in cols] +
        [f.lead(f.col(c), default=0).over(w).alias("next_" + c) for c in cols]
    )
).drop("id").show()
#+---+---+---+---+------+------+------+------+------+------+------+------+
#| a| b| c| d|prev_a|prev_b|prev_c|prev_d|next_a|next_b|next_c|next_d|
#+---+---+---+---+------+------+------+------+------+------+------+------+
#| 1| 0| 1| 0| 0| 0| 0| 0| 0| 0| 1| 1|
#| 0| 0| 1| 1| 1| 0| 1| 0| 0| 1| 0| 1|
#| 0| 1| 0| 1| 0| 0| 1| 1| 0| 0| 0| 0|
#+---+---+---+---+------+------+------+------+------+------+------+------+
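One caveat: a Window with an orderBy but no partitionBy moves all of the data into a single partition (Spark logs a warning to that effect), so this approach won't scale well to large DataFrames. If your data has a natural grouping column, partition the window on it so each group is shifted independently. A minimal sketch, assuming a hypothetical grouping column grp:

w = Window.partitionBy("grp").orderBy("id")

Also note that monotonically_increasing_id() guarantees increasing, unique ids but not consecutive ones, which is fine here since the window only needs a sort order.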