
I am using a simple groupBy query in Scala Spark where the objective is to get the first value in each group from a sorted dataframe. Here is my Spark dataframe:

+---------------+----------+-----------+-------------------+
|ID             |some_flag |some_type  |  Timestamp        |
+---------------+------------------------------------------+
|      656565654|      true|     Type 1|2018-08-10 00:00:00|
|      656565654|     false|     Type 1|2017-08-02 00:00:00|
|      656565654|     false|     Type 2|2016-07-30 00:00:00|
|      656565654|     false|     Type 2|2016-05-04 00:00:00|
|      656565654|     false|     Type 2|2016-04-29 00:00:00|
|      656565654|     false|     Type 2|2015-10-29 00:00:00|
|      656565654|     false|     Type 2|2015-04-29 00:00:00|
+---------------+----------+-----------+-------------------+

Here is my aggregate query

val sampleDF = df.sort($"Timestamp".desc).groupBy("ID").agg(first("Timestamp"), first("some_flag"), first("some_type"))

The expected result is

+---------------+-------------+---------+-------------------+
|ID             |some_flag    |some_type|  Timestamp        |
+---------------+-------------+---------+-------------------+
|      656565654|         true|   Type 1|2018-08-10 00:00:00|
+---------------+-------------+---------+-------------------+

But I am getting the following weird output instead, and it keeps changing, as if a random row from the group is picked:

+---------------+-------------+---------+-------------------+
|ID             |some_flag    |some_type|  Timestamp        |
+---------------+-------------+---------+-------------------+
|      656565654|        false|   Type 2|2015-10-29 00:00:00|
+---------------+-------------+---------+-------------------+

Also please note that there are no nulls in the dataframe. I am scratching my head trying to figure out what I am doing wrong. Need help!

muazfaiz

2 Answers


The way you are trying to get all the first values does not give a consistent result: each aggregated column value might come from a different row in the group.

Instead, order by timestamp in descending order within each group and take the first row. An easy way to do that is with a window function like row_number.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// number the rows within each ID, newest Timestamp first
val byIdNewestFirst = Window.partitionBy(col("ID")).orderBy(col("Timestamp").desc)
val sampleDF = df.withColumn("rnum", row_number().over(byIdNewestFirst))

// keep only the newest row per ID (note ===, not ==, for a Column comparison)
sampleDF.filter(col("rnum") === 1).show
Vamsi Prabhala

Just to add to Vamsi's answer: the problem is that the values in a groupBy result group aren't returned in any particular order (particularly given the distributed nature of Spark operations), so the first function is perhaps misleadingly named. It returns the first non-null value that it finds for that column, i.e. pretty much any non-null value for that column within the group.

Sorting your rows before the groupBy doesn't affect the order within the group in any reproducible way.

See also this blog post which explains that, because of the behaviour above, the values you get from multiple first calls may not even be from the same row within the group.

Input data with 3 columns "k, t, v":

z, 1, null
z, 2, 1.5
z, 3, 2.4

Code:

df.groupBy("k").agg(
  $"k",
  first($"t"),
  first($"v")
)

Output:

z, 1, 1.5

This result is a mix of 2 records!
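
If you want all of the output columns to come from the same (latest) row without using a window, one alternative sketch is to pack the columns into a struct led by the timestamp and take its max, so the whole row wins or loses together. Column names are taken from the question's dataframe; adjust as needed:

import org.apache.spark.sql.functions._

// Structs compare field by field, so putting Timestamp first means
// max() picks the struct belonging to the latest row per ID.
val latestPerId = df
  .groupBy("ID")
  .agg(max(struct($"Timestamp", $"some_flag", $"some_type")).as("latest"))
  .select($"ID", $"latest.Timestamp", $"latest.some_flag", $"latest.some_type")

latestPerId.show(false)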

DNA