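This is quite late, but today I tried to implement the equivalent of a recursive CTE query using PySpark.
Here is a simple DataFrame in which each row maps an old ID to the ID that replaced it. What I want is to find the newest ID for each ID, i.e. the ID at the end of its replacement chain.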
The original DataFrame:

```
+-----+-----+
|OldID|NewID|
+-----+-----+
|    1|    2|
|    2|    3|
|    3|    4|
|    4|    5|
|    6|    7|
|    7|    8|
|    9|   10|
+-----+-----+
```
The result I want:

```
+-----+-----+
|OldID|NewID|
+-----+-----+
|    1|    5|
|    2|    5|
|    3|    5|
|    4|    5|
|    6|    8|
|    7|    8|
|    9|   10|
+-----+-----+
```
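For comparison, here is roughly what the recursive CTE would look like in an engine that supports `WITH RECURSIVE`. This sketch uses Python's built-in `sqlite3` module as a stand-in, not Spark, and the table name `ids` is made up for illustration. The anchor member selects the rows whose `NewID` is never replaced. The recursive member then steps one link back along each chain.

```python
import sqlite3

rows = [(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8), (9, 10)]

# In-memory SQLite database used only as a stand-in engine (not Spark)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE ids (OldID INTEGER, NewID INTEGER)")  # hypothetical table name
conn.executemany("INSERT INTO ids VALUES (?, ?)", rows)

query = """
WITH RECURSIVE chain(OldID, NewID) AS (
    -- Anchor: rows whose NewID never appears as an OldID (never replaced)
    SELECT OldID, NewID FROM ids
    WHERE NewID NOT IN (SELECT OldID FROM ids)
    UNION ALL
    -- Recursive step: the row replaced by chain.OldID inherits chain.NewID
    SELECT ids.OldID, chain.NewID
    FROM ids JOIN chain ON ids.NewID = chain.OldID
)
SELECT OldID, NewID FROM chain ORDER BY OldID
"""
result = conn.execute(query).fetchall()
print(result)
```

Running it prints the seven `(OldID, NewID)` pairs of the desired result.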
Here is my code:
```python
from pyspark.sql.functions import broadcast

# checkpoint() needs a checkpoint directory; this path is only an example
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

df = spark.createDataFrame([(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8), (9, 10)], "OldID integer, NewID integer").checkpoint().cache()
dfcheck = df.drop('NewID')
dfdistinctID = df.select('NewID').distinct()
dfidfinal = dfdistinctID.join(dfcheck, [dfcheck.OldID == dfdistinctID.NewID], how="left_anti")  # IDs that have never been replaced
dfcurrent = df.join(dfidfinal, [dfidfinal.NewID == df.NewID], how="left_semi").checkpoint().cache()  # rows that point at one of those final IDs
dfresult = dfcurrent
dfdifferentalias = df.select(df.OldID.alias('id1'), df.NewID.alias('id2')).checkpoint().cache()
# Walk every chain one step backwards per iteration until no new rows appear
while dfcurrent.count() > 0:
    dfcurrent = (dfcurrent.join(broadcast(dfdifferentalias), [dfcurrent.OldID == dfdifferentalias.id2], how="inner")
                 .select(dfdifferentalias.id1.alias('OldID'), dfcurrent.NewID.alias('NewID'))
                 .cache())
    dfresult = dfresult.union(dfcurrent)
display(dfresult.orderBy('OldID'))  # display() is Databricks-only; use .show() elsewhere
```
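To check the logic without a cluster, here is a plain-Python model of the same loop. It is a sketch, not Spark code. A set difference stands in for the `left_anti` join and a filter stands in for the `left_semi` join. A nested comprehension stands in for the broadcast inner join. The `iterations` counter is added only for illustration.

```python
rows = [(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8), (9, 10)]

old_ids = {old for old, _ in rows}
# left_anti join: NewIDs that never appear as an OldID are final
final_ids = {new for _, new in rows} - old_ids
# left_semi join: rows that already point at a final ID seed the result
current = [(old, new) for old, new in rows if new in final_ids]
result = list(current)

iterations = 0  # hypothetical counter, not in the PySpark version
while current:
    # inner join on current.OldID == id2: step one link back along each chain,
    # keeping the final NewID from the current row
    current = [(id1, new) for old, new in current for id1, id2 in rows if id2 == old]
    result.extend(current)  # plays the role of union
    iterations += 1

result.sort()
print(result)
print(iterations)
```

On the sample data, the loop body runs 4 times, once per link in the longest chain (1 → 2 → 3 → 4 → 5). In the Spark version, each of those rounds also triggers a `count()` job.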
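I know that the performance is quite bad, but at least it gives the answer I need. The loop triggers a `count()` job for every step along the longest ID chain, and the plan of `dfresult` keeps growing with every union.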
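A different technique, not used in the answer above, could reduce the number of rounds: pointer jumping (also called path doubling). In each round, every ID is mapped to the target of its current target, so the remaining chain length roughly halves. The plain-Python sketch below assumes the ID chains contain no cycles. In Spark, each round would presumably correspond to one self-join of the mapping table.

```python
rows = [(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8), (9, 10)]

parent = dict(rows)  # OldID -> best-known newer ID so far

rounds = 0
while True:
    # Jump: replace each target by its own target (if it has one)
    nxt = {old: parent.get(new, new) for old, new in parent.items()}
    rounds += 1
    if nxt == parent:  # nothing changed, so every ID points at a final ID
        break
    parent = nxt

print(sorted(parent.items()))
print(rounds)
```

On this tiny sample, it needs 3 rounds, and the last round only confirms that nothing changed. The saving over the step-by-step loop grows with chain length, because the number of rounds grows roughly with log2 of the longest chain rather than linearly.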
This is the first time I have posted an answer on Stack Overflow, so forgive me if I have made any mistakes.