The following solution uses PySpark SQL functions to implement the logic requested above.
Setup
Create a DataFrame to mimic the example provided
# Imports used by every step below (the steps reference F.* and Window.*).
# Assumes an active SparkSession is bound to the name `spark`, as in the
# pyspark shell or a notebook.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# One single-column row per id, in the order given in the question.
df = spark.createDataFrame(
    [('a',),
     ('c',),
     ('c',),
     ('d',),
     ('b',),
     ('a',),
     ('a',),
     ('d',),
     ('d',),
     ('c',),
     ('c',),
     ('b',),
     ('a',),
     ('b',)],
    ['id'])
Output
+---+
|id |
+---+
|a |
|c |
|c |
|d |
|b |
|a |
|a |
|d |
|d |
|c |
|c |
|b |
|a |
|b |
+---+
Logic
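- Calculate a row number (`row_num`) for each row, preserving the DataFrame's current order (the original answer linked to a reference for this row-number technique)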
# monotonically_increasing_id() grows with the DataFrame's current row order
# but is not consecutive; row_number() over it yields a gap-free 1..n index.
# Note: an unpartitioned window moves all rows into a single partition.
order_window = Window.orderBy(F.monotonically_increasing_id())
df = df.withColumn("row_num", F.row_number().over(order_window))
- Use the row number to find the preceding id value (the lag). The first row has no preceding id, so its lag is null; fill that missing value with "c".
# lag_id = id of the previous row by row_num. The first row has no
# predecessor, so its null lag is filled with "c".
df = (
    df.withColumn("lag_id", F.lag("id", 1).over(Window.orderBy("row_num")))
      .fillna("c", subset=["lag_id"])
)
Output
+---+--------------+------+
|id | row_num |lag_id|
+---+--------------+------+
|a |1 |c |
|c |2 |a |
|c |3 |c |
|d |4 |c |
|b |5 |d |
|a |6 |b |
|a |7 |a |
|d |8 |a |
|d |9 |d |
|c |10 |d |
|c |11 |c |
|b |12 |c |
|a |13 |b |
|b |14 |a |
+---+--------------+------+
- Determine the order (`sequence`) of the rows that immediately follow a row where id = "c" (rows whose own id is "c" are excluded)
# Rows that are not "c" themselves but whose predecessor is "c" (row 1
# qualifies because its missing lag was filled with "c"). Number those
# rows 1, 2, 3, ... in row_num order.
is_after_c = (F.col("id") != "c") & (F.col("lag_id") == "c")
df_sequence = (
    df.where(is_after_c)
      .withColumn("sequence", F.row_number().over(Window.orderBy("row_num")))
)
Output
+---+--------------+------+--------+
|id | row_num |lag_id|sequence|
+---+--------------+------+--------+
|a |1 |c |1 |
|d |4 |c |2 |
|b |12 |c |3 |
+---+--------------+------+--------+
- Join the sequence DF to the original DF
# Attach `sequence` to every row of df by row_num. Only row_num and sequence
# are taken from df_sequence (which is itself derived from df). This avoids
# duplicate id/lag_id columns and ambiguous column references after the join.
# Rows that do not start a sequence get a null `sequence` from the left join.
df_joined = (
    df.join(df_sequence.select("row_num", "sequence"),
            on="row_num",
            how="leftouter")
      # A USING-join puts the key first; restore the displayed column order.
      .select("id", "row_num", "lag_id", "sequence")
)
- Set `sequence` to 0 where id = "c"
# "c" rows belong to no sequence. Marking them 0 (instead of leaving them null)
# stops the forward fill from carrying the previous sequence number across them.
df_joined = df_joined.withColumn(
    "sequence",
    F.when(F.col("id") == "c", 0).otherwise(F.col("sequence")),
)
Output
+---+--------------+------+--------+
|id | row_num |lag_id|sequence|
+---+--------------+------+--------+
|a |1 |c |1 |
|c |2 |a |0 |
|c |3 |c |0 |
|d |4 |c |2 |
|b |5 |d |null |
|a |6 |b |null |
|a |7 |a |null |
|d |8 |a |null |
|d |9 |d |null |
|c |10 |d |0 |
|c |11 |c |0 |
|b |12 |c |3 |
|a |13 |b |null |
|b |14 |a |null |
+---+--------------+------+--------+
- Forward-fill the null `sequence` values with the last non-null value (the original answer linked to a reference for this forward-fill technique)
# Forward fill: each row takes the most recent non-null sequence at or before it
# in row_num order. The frame is stated explicitly rather than relying on the
# default frame of an ordered window.
fill_window = (
    Window.orderBy("row_num")
          .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df_final = df_joined.withColumn(
    "sequence",
    F.last("sequence", ignorenulls=True).over(fill_window),
)
Final Output
+---+--------------+------+--------+
|id | row_num |lag_id|sequence|
+---+--------------+------+--------+
|a |1 |c |1 |
|c |2 |a |0 |
|c |3 |c |0 |
|d |4 |c |2 |
|b |5 |d |2 |
|a |6 |b |2 |
|a |7 |a |2 |
|d |8 |a |2 |
|d |9 |d |2 |
|c |10 |d |0 |
|c |11 |c |0 |
|b |12 |c |3 |
|a |13 |b |3 |
|b |14 |a |3 |
+---+--------------+------+--------+