I am having difficulty implementing this existing answer: PySpark - get row number for each row in a group
Consider the following:
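# imports (assumes an active SparkSession `spark` and SparkContext `sc`, as in the pyspark shell)
from pyspark.sql import Window
import pyspark.sql.functions as F
# create df
df = spark.createDataFrame(sc.parallelize([
    [1, 'A', 20220722, 1],
    [1, 'A', 20220723, 1],
    [1, 'B', 20220724, 2],
    [2, 'B', 20220722, 1],
    [2, 'C', 20220723, 2],
    [2, 'B', 20220724, 3],
]),
['ID', 'State', 'Time', 'Expected'])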
# rank
w = Window.partitionBy('State').orderBy('ID', 'Time')
df = df.withColumn('rn', F.row_number().over(w))
df = df.withColumn('rank', F.rank().over(w))
df = df.withColumn('dense', F.dense_rank().over(w))
# view
df.show()
+---+-----+--------+--------+---+----+-----+
| ID|State|    Time|Expected| rn|rank|dense|
+---+-----+--------+--------+---+----+-----+
|  1|    A|20220722|       1|  1|   1|    1|
|  1|    A|20220723|       1|  2|   2|    2|
|  1|    B|20220724|       2|  1|   1|    1|
|  2|    B|20220722|       1|  2|   2|    2|
|  2|    B|20220724|       3|  3|   3|    3|
|  2|    C|20220723|       2|  1|   1|    1|
+---+-----+--------+--------+---+----+-----+
None of rn, rank, or dense matches the Expected column. How can I compute the values in Expected and also sort the rows so that the dates in Time are ascending?
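For reference, the Expected column in the sample appears to follow this rule: within each ID, ordered by Time, the counter starts at 1 and increases by 1 whenever State differs from the previous row. That reading is an assumption inferred from the six sample rows, not something stated elsewhere. A minimal plain-Python sketch of that rule (no Spark involved; the helper name `state_run_index` is hypothetical) reproduces the Expected values:

```python
from itertools import groupby
from operator import itemgetter

# Sample rows from the question: (ID, State, Time, Expected)
rows = [
    (1, 'A', 20220722, 1),
    (1, 'A', 20220723, 1),
    (1, 'B', 20220724, 2),
    (2, 'B', 20220722, 1),
    (2, 'C', 20220723, 2),
    (2, 'B', 20220724, 3),
]


def state_run_index(rows):
    """Hypothetical helper: return (ID, State, Time, run) tuples.

    Rows are sorted by ID, then Time. `run` starts at 1 for each ID and
    increases by 1 whenever State differs from the previous row of that ID
    (assumed meaning of the Expected column).
    """
    out = []
    ordered = sorted(rows, key=itemgetter(0, 2))  # sort by ID, then Time
    for _, group in groupby(ordered, key=itemgetter(0)):  # one group per ID
        run, prev_state = 0, None
        for id_, state, time, *_ in group:
            if state != prev_state:  # State changed (or first row of the ID)
                run += 1
                prev_state = state
            out.append((id_, state, time, run))
    return out


result = state_run_index(rows)
for row in result:
    print(row)
```

If that reading is right, the partition key is probably ID rather than State. A PySpark version would likely use `Window.partitionBy('ID').orderBy('Time')`, compare State with `F.lag('State')` over that window to flag changes, and take a running `F.sum` of the flag. That translation is only outlined here and has not been run against the data.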