I have an Apache Spark DataFrame in the following format:
| ID | groupId | phaseName |
|----|-----------|-----------|
| 10 | someHash1 | PhaseA |
| 11 | someHash1 | PhaseB |
| 12 | someHash1 | PhaseB |
| 13 | someHash2 | PhaseX |
| 14 | someHash2 | PhaseY |
Each row represents a phase in a procedure that consists of several such phases. The `ID` column gives the sequential order of the phases, and the `groupId` column shows which phases belong together.
I want to add a new column to the dataframe: `prevPhaseName`. This column should hold the previous different phase from the same procedure. The first phase of a procedure (the one with the minimum `ID`) has `null` as its previous phase. When a phase occurs two or more times in a row, the second (third, ...) occurrence has the same `prevPhaseName` as the first. For example:
df =
| ID | groupId | phaseName | prevPhaseName |
|----|-----------|-----------|---------------|
| 10 | someHash1 | PhaseA | null |
| 11 | someHash1 | PhaseB | PhaseA |
| 12 | someHash1 | PhaseB | PhaseA |
| 13 | someHash2 | PhaseX | null |
| 14 | someHash2 | PhaseY | PhaseX |
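To pin down the intended semantics independently of Spark, here is a minimal plain-Java sketch of the rule described above. The class `PrevPhaseReference` and the method `prevPhaseNames` are hypothetical names made up for this illustration, and the sketch assumes the rows are already sorted by `ID`. It is only a reference for checking results on small samples, not a distributed solution.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not part of Spark: a plain-Java reference for the desired column.
class PrevPhaseReference {

    // Each input row is {ID, groupId, phaseName}; rows are assumed to be sorted by ID.
    static List<String> prevPhaseNames(List<String[]> rowsSortedById) {
        Map<String, String> currentPhase = new HashMap<>();  // groupId -> phase of the latest row
        Map<String, String> previousPhase = new HashMap<>(); // groupId -> last different phase
        List<String> result = new ArrayList<>();
        for (String[] row : rowsSortedById) {
            String groupId = row[1];
            String phaseName = row[2];
            String current = currentPhase.get(groupId);
            // Only a change of phase name moves the "previous different phase" forward.
            if (current != null && !current.equals(phaseName)) {
                previousPhase.put(groupId, current);
            }
            currentPhase.put(groupId, phaseName);
            result.add(previousPhase.get(groupId)); // stays null until the first change in a group
        }
        return result;
    }

    public static void main(String[] args) {
        List<String[]> rows = Arrays.asList(
            new String[] {"10", "someHash1", "PhaseA"},
            new String[] {"11", "someHash1", "PhaseB"},
            new String[] {"12", "someHash1", "PhaseB"},
            new String[] {"13", "someHash2", "PhaseX"},
            new String[] {"14", "someHash2", "PhaseY"});
        System.out.println(prevPhaseNames(rows));
        // prints [null, PhaseA, PhaseA, null, PhaseX]
    }
}
```

The printed list matches the `prevPhaseName` column of the example table, row by row.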
I am not sure how to implement this. My first approach would be:
- create a second, empty dataframe df2
- for each row in df:
  - find the row with groupId = row.groupId, ID < row.ID, and the maximum ID
  - add this row to df2
- join df and df2
Partial Solution using Window Functions
I used window functions to compute the name of the previous phase, the number of occurrences so far of the current phase in the group (not necessarily consecutive, and including the current row), and whether the current and previous phase names are equal:
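import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.functions;

// All rows of a group, ordered by ID: used to look up the directly preceding phase
WindowSpec windowSpecPrev = Window
    .partitionBy(df.col("groupId"))
    .orderBy(df.col("ID"));
// Rows of the same phase within a group, from the first row up to the current one
WindowSpec windowSpecCount = Window
    .partitionBy(df.col("groupId"), df.col("phaseName"))
    .orderBy(df.col("ID"))
    .rowsBetween(Long.MIN_VALUE, 0);
df = df
    .withColumn("prevPhase", functions.lag("phaseName", 1).over(windowSpecPrev))
    // count the existing ID column (the dataframe has no phaseId column)
    .withColumn("phaseCount", functions.count("ID").over(windowSpecCount))
    .withColumn("prevSame", functions.when(functions.col("prevPhase").equalTo(functions.col("phaseName")), 1).otherwise(0));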
df =
| ID | groupId | phaseName | prevPhase | phaseCount | prevSame |
|----|-----------|-----------|-------------|------------|----------|
| 10 | someHash1 | PhaseA | null | 1 | 0 |
| 11 | someHash1 | PhaseB | PhaseA | 1 | 0 |
| 12 | someHash1 | PhaseB | PhaseB | 2 | 1 |
| 13 | someHash2 | PhaseX | null | 1 | 0 |
| 14 | someHash2 | PhaseY | PhaseX | 1 | 0 |
This is still not what I wanted to achieve, but it is good enough for now.
Further Ideas
To get the name of the previous distinct phase, I see three possibilities that I have not investigated thoroughly:
- Implement a custom `lag` function that does not take an offset but instead checks the preceding rows one by one until it finds a value that differs from the current row. (Though I don't think it's possible to define custom analytic window functions in Spark SQL.)
- Find a way to set the offset of the `lag` function dynamically, according to the value of `phaseCount`. (That may fail if the previous occurrences of the same phaseName are not consecutive.)
- Use a `UserDefinedAggregateFunction` over the window that stores the ID and phaseName of the first given input and looks for the highest ID with a different phaseName.
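A fourth possibility, sketched here only in plain Java for a single group, is to split the problem into two steps: first keep the lagged phase name only on rows where the phase actually changed, then carry the last non-null marker forward. The class `LagThenFill` and its methods are hypothetical names for this illustration. In Spark the two steps would presumably correspond to `functions.lag` followed by `functions.last` with `ignoreNulls` set to `true` over a window running from the start of the partition to the current row, but that mapping is an assumption and is not tested here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration class; it only emulates the two window steps on one partition.
class LagThenFill {

    // Step 1: emulate lag(phaseName, 1), but keep the value only where the phase changed.
    static List<String> changeMarkers(List<String> phases) {
        List<String> markers = new ArrayList<>();
        for (int i = 0; i < phases.size(); i++) {
            String prev = (i == 0) ? null : phases.get(i - 1);
            boolean changed = prev != null && !prev.equals(phases.get(i));
            markers.add(changed ? prev : null);
        }
        return markers;
    }

    // Step 2: emulate "last non-null value from the partition start up to the current row".
    static List<String> forwardFill(List<String> markers) {
        List<String> filled = new ArrayList<>();
        String last = null;
        for (String marker : markers) {
            if (marker != null) {
                last = marker;
            }
            filled.add(last);
        }
        return filled;
    }

    public static void main(String[] args) {
        // One group, already ordered by ID; PhaseA reappears after PhaseB.
        List<String> phases = Arrays.asList("PhaseA", "PhaseB", "PhaseB", "PhaseA");
        System.out.println(forwardFill(changeMarkers(phases)));
        // prints [null, PhaseA, PhaseA, PhaseB]
    }
}
```

Unlike a `lag` offset derived from `phaseCount`, this does not depend on earlier occurrences of the same phase being consecutive, since each row only looks at the most recent change point.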