I have a DataFrame as follows:

from pyspark.sql import functions as f
from pyspark.sql.window import Window

df = spark.createDataFrame([
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:30:57.000", "Username": "user1", "Region": "US"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:31:57.014", "Username": "user2", "Region": "US"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:32:57.914", "Username": "user1", "Region": "MX"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:35:57.914", "Username": "user2", "Region": "CA"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:33:57.914", "Username": "user1", "Region": "UK"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:34:57.914", "Username": "user1", "Region": "GR"},
  {"groupId":"A","Day":"2021-01-27", "ts": "2021-01-27 08:36:57.914", "Username": "user2", "Region": "IR"}])

w = Window.partitionBy().orderBy("groupId","Username").orderBy("Username","ts")
df2 = df.withColumn("prev_region", f.lag(df.Region).over(w))
+----------+------+--------+-------+-----------------------+
|       Day|Region|Username|groupId|                     ts|
+----------+------+--------+-------+-----------------------+
|2021-01-27|    US|   user1|      A|2021-01-27 08:30:57.000|
|2021-01-27|    MX|   user1|      A|2021-01-27 08:32:57.914|
|2021-01-27|    UK|   user1|      A|2021-01-27 08:33:57.914|
|2021-01-27|    GR|   user1|      A|2021-01-27 08:34:57.914|
|2021-01-27|    US|   user2|      A|2021-01-27 08:31:57.014|
|2021-01-27|    CA|   user2|      A|2021-01-27 08:35:57.914|
|2021-01-27|    IR|   user2|      A|2021-01-27 08:36:57.914|
+----------+------+--------+-------+-----------------------+

I want to know each user's previous region, so I used the lag function. This is the result:

+----------+------+--------+-------+-----------------------+-----------+
|       Day|Region|Username|groupId|                     ts|prev_region|
+----------+------+--------+-------+-----------------------+-----------+
|2021-01-27|    US|   user1|      A|2021-01-27 08:30:57.000|       null|
|2021-01-27|    MX|   user1|      A|2021-01-27 08:32:57.914|         US|
|2021-01-27|    UK|   user1|      A|2021-01-27 08:33:57.914|         MX|
|2021-01-27|    GR|   user1|      A|2021-01-27 08:34:57.914|         UK|
|2021-01-27|    US|   user2|      A|2021-01-27 08:31:57.014|         GR|
|2021-01-27|    CA|   user2|      A|2021-01-27 08:35:57.914|         US|
|2021-01-27|    IR|   user2|      A|2021-01-27 08:36:57.914|         CA|
+----------+------+--------+-------+-----------------------+-----------+

As you can see, the value of the "prev_region" column in the first record of user2 is expected to be null; instead, it wrongly carries over "GR" from user1. I would be thankful if you could show me how to fix it.

user2967251

2 Answers

You are almost there.

Based on your DataFrame, specifying the window as follows will work.

# Python API
>>> w = Window.partitionBy("Username").orderBy("groupId", "Username", "ts")
>>> df2 = df.withColumn("prev_region", f.lag(df.Region).over(w))
>>> df2.show(truncate=100)
+----------+------+--------+-------+-----------------------+-----------+
|       Day|Region|Username|groupId|                     ts|prev_region|
+----------+------+--------+-------+-----------------------+-----------+
|2021-01-27|    US|   user1|      A|2021-01-27 08:30:57.000|       null|
|2021-01-27|    MX|   user1|      A|2021-01-27 08:32:57.914|         US|
|2021-01-27|    UK|   user1|      A|2021-01-27 08:33:57.914|         MX|
|2021-01-27|    GR|   user1|      A|2021-01-27 08:34:57.914|         UK|
|2021-01-27|    US|   user2|      A|2021-01-27 08:31:57.014|       null|
|2021-01-27|    CA|   user2|      A|2021-01-27 08:35:57.914|         US|
|2021-01-27|    IR|   user2|      A|2021-01-27 08:36:57.914|         CA|
+----------+------+--------+-------+-----------------------+-----------+

# SQL API
df.createOrReplaceTempView("df")
result = spark.sql("""
    SELECT
        Day, Region, Username, groupId, ts,
        LAG(Region) OVER (PARTITION BY Username ORDER BY groupId, Username, ts) AS prev_region
    FROM df
    """)
result.show(truncate=100)
+----------+------+--------+-------+-----------------------+-----------+
|       Day|Region|Username|groupId|                     ts|prev_region|
+----------+------+--------+-------+-----------------------+-----------+
|2021-01-27|    US|   user1|      A|2021-01-27 08:30:57.000|       null|
|2021-01-27|    MX|   user1|      A|2021-01-27 08:32:57.914|         US|
|2021-01-27|    UK|   user1|      A|2021-01-27 08:33:57.914|         MX|
|2021-01-27|    GR|   user1|      A|2021-01-27 08:34:57.914|         UK|
|2021-01-27|    US|   user2|      A|2021-01-27 08:31:57.014|       null|
|2021-01-27|    CA|   user2|      A|2021-01-27 08:35:57.914|         US|
|2021-01-27|    IR|   user2|      A|2021-01-27 08:36:57.914|         CA|
+----------+------+--------+-------+-----------------------+-----------+

If there is more than one group (multiple groupIds), state the window as follows:

>>> w = Window.partitionBy("groupId", "Username").orderBy("groupId", "ts", "Username")
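What this multi-column partition computes can be checked without a Spark session. Below is a minimal pure-Python sketch of `LAG(Region) OVER (PARTITION BY groupId, Username ORDER BY ts)` — the helper name `lag_by_partition` and the shortened sample rows are my own illustration, not part of the PySpark API:

```python
from itertools import groupby
from operator import itemgetter

def lag_by_partition(rows, partition_keys, order_key, value_key):
    """Mimic LAG(value) OVER (PARTITION BY keys ORDER BY order_key):
    the first row of each partition gets None; every later row gets
    the previous row's value within the same partition."""
    key = itemgetter(*partition_keys)
    out = []
    # Sort by partition key first so groupby sees each partition contiguously.
    for _, group in groupby(sorted(rows, key=lambda r: (key(r), r[order_key])), key=key):
        prev = None
        for row in group:
            out.append({**row, "prev_region": prev})
            prev = row[value_key]
    return out

rows = [
    {"groupId": "A", "Username": "user1", "ts": "08:30", "Region": "US"},
    {"groupId": "A", "Username": "user2", "ts": "08:31", "Region": "US"},
    {"groupId": "A", "Username": "user1", "ts": "08:32", "Region": "MX"},
    {"groupId": "B", "Username": "user1", "ts": "08:33", "Region": "UK"},
]

result = lag_by_partition(rows, ("groupId", "Username"), "ts", "Region")
```

Note that the row for groupId "B" gets a null previous region even though user1 also appears in group "A" — that is exactly why groupId must be part of the partition when there are multiple groups.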
Scott Hsieh

You just need to add your Username column to the partitionBy function. Also, there is no need to chain two orderBy calls — the second call simply replaces the ordering set by the first. Change your line to:

w = Window.partitionBy('Username').orderBy("ts")
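As a quick sanity check of this fix outside Spark, here is a pure-Python sketch of `PARTITION BY Username ORDER BY ts` applied to the question's rows (the dict-based loop is my own illustration, not PySpark API):

```python
# (Username, ts, Region) tuples mirroring the question's data.
rows = [
    ("user1", "08:30:57.000", "US"),
    ("user2", "08:31:57.014", "US"),
    ("user1", "08:32:57.914", "MX"),
    ("user2", "08:35:57.914", "CA"),
    ("user1", "08:33:57.914", "UK"),
    ("user1", "08:34:57.914", "GR"),
    ("user2", "08:36:57.914", "IR"),
]

# PARTITION BY Username ORDER BY ts: track the previous Region per user,
# so each user's first row sees None rather than another user's value.
prev_by_user = {}
lagged = []
for user, ts, region in sorted(rows, key=lambda r: (r[0], r[1])):
    lagged.append((user, ts, region, prev_by_user.get(user)))
    prev_by_user[user] = region
```

With this partitioning, user2's first record at 08:31:57.014 gets a null previous region instead of inheriting "GR" from user1.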
Amir Maleki