194

I have a DataFrame generated as follows:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc)

The results look like:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

As you can see, the DataFrame is ordered by Hour in increasing order, then by TotalValue in descending order.

I would like to select the top row of each group, i.e.

  • from the group of Hour==0 select (0,cat26,30.9)
  • from the group of Hour==1 select (1,cat67,28.5)
  • from the group of Hour==2 select (2,cat56,39.6)
  • and so on

So the desired output would be:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

It might be handy to be able to select the top N rows of each group as well.

Any help is highly appreciated.

zero323
Rami

10 Answers

300

Window functions:

Something like this should do the trick:

import org.apache.spark.sql.functions.{row_number, max, broadcast, first, struct}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

This method will be inefficient in case of significant data skew. This problem is tracked by SPARK-34775 and might be resolved in the future (SPARK-37099).
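The question also asks about the top N rows per group. With the same window, that is a matter of filtering on `rn <= n` instead of `rn === 1`. The following is a minimal sketch, assuming a local SparkSession and a hypothetical `n = 2`; note that `row_number` breaks ties arbitrarily, so `rank` may be a better fit if tied rows should all be kept.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("top-n-per-group").getOrCreate()
import spark.implicits._

val df = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3)
).toDF("Hour", "Category", "TotalValue")

val n = 2 // hypothetical: number of rows to keep per Hour

val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

// Number the rows within each Hour (highest TotalValue first), then keep the first n.
val dfTopN = df
  .withColumn("rn", row_number().over(w))
  .where($"rn" <= n)
  .drop("rn")

dfTopN.orderBy($"Hour", $"TotalValue".desc).show()
```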

Plain SQL aggregation followed by join:

Alternatively you can join with aggregated data frame:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Using ordering over structs:

A neat, although not very well tested, trick that requires neither joins nor window functions:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+
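The struct trick works because Spark compares structs field by field, from left to right, which is why `TotalValue` has to come first in the struct. A side effect is that ties on `TotalValue` are broken by the next field. The sketch below illustrates this with made-up data (the category names and values are assumptions, not taken from the question) and a local SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{max, struct}

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("struct-ordering").getOrCreate()
import spark.implicits._

// Hypothetical data: Hour 0 has two categories tied at 10.0.
val df = Seq(
  (0, "catA", 10.0), (0, "catB", 10.0), (0, "catC", 5.0),
  (1, "catD", 7.0)
).toDF("Hour", "Category", "TotalValue")

// max over a struct compares TotalValue first and falls back to Category on ties.
val dfTop = df
  .groupBy($"Hour")
  .agg(max(struct($"TotalValue", $"Category")).alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

// For Hour 0 the tie is resolved in favour of "catB", the larger string.
dfTop.orderBy($"Hour").show()
```

Unlike the join-based approach, this always returns exactly one row per group, so no extra deduplication step is needed.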

With the Dataset API (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 or later:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should perform better than window functions and joins. They can also be used with Structured Streaming in complete output mode.
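Note that `reduceGroups` returns a `Dataset` of `(key, value)` pairs rather than plain records. A minimal sketch of getting back to a flat `Dataset[Record]`, assuming a local SparkSession and a `Record` case class with a primitive `Int` key (a small departure from the `Integer` field used above):

```scala
import org.apache.spark.sql.SparkSession

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("reduce-groups").getOrCreate()
import spark.implicits._

case class Record(Hour: Int, Category: String, TotalValue: Double)

val ds = Seq(
  Record(0, "cat26", 30.9), Record(0, "cat13", 22.1),
  Record(1, "cat67", 28.5), Record(1, "cat4", 26.8)
).toDS()

// reduceGroups yields Dataset[(Int, Record)]; keep only the reduced record.
val top = ds
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map(_._2)

top.orderBy($"Hour").show()
```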

Don't use:

df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in the local mode) but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking relevant JIRA issue, and SPARK-30335).

The same note applies to

df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.

10465355
zero323
  • 3
    It looks like since spark 1.6 it is [row_number()](https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.functions$@row_number():org.apache.spark.sql.Column) instead of rowNumber – Adam Szałucha Sep 15 '17 at 13:53
  • About the "Don't use df.orderBy(...).groupBy(...)" note: under what circumstances can we rely on orderBy(...)? And if we cannot be sure that orderBy() will give the correct result, what alternatives do we have? – Ignacio Alorre Sep 27 '17 at 12:35
  • 1
    I might be overlooking something, but in general it is recommended to [avoid groupByKey](https://databricks.gitbooks.io/databricks-spark-knowledge-base/content/best_practices/prefer_reducebykey_over_groupbykey.html), instead reduceByKey should be used. Also, you'll be saving one line. – Thomas Feb 19 '18 at 15:17
  • 3
    @Thomas avoiding groupBy/groupByKey is just when dealing with RDDs, you'll notice that the Dataset api doesn't even have a reduceByKey function. – soote May 17 '18 at 04:19
  • @Thomas [DataFrame / Dataset groupBy behaviour/optimization](https://stackoverflow.com/q/32902982/8371915) – Alper t. Turker Jun 06 '18 at 21:57
  • Update: both of those spark issues have been resolved, `first` should be good to use – Brendan May 12 '20 at 03:42
  • @Brendan from which version is it safe to use? – Abuw Feb 17 '21 at 04:33
  • I'm not sure exactly how to read Spark's issue tracker to confirm which version this was fixed in, but following the links above shows both issues resolved. One had the affected version listed as 1.6.1 and the other as 3.1.0. – Brendan Feb 18 '21 at 05:06
  • 2
    if you read both tickets attentively, you'll notice that while they are marked as "resolved", they didn't change any underlying behavior for `first()`. They only document that the output is non-deterministic. As such, `first()` without a window function remains unreliable to use... – hiryu Aug 26 '21 at 10:09
  • Compare this `df.as[Record] .groupByKey(_.Hour) .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y) ` with the spark3 , maxBy API, which one would perform better? @zero323 – soMuchToLearnAndShare May 16 '23 at 03:30
23

For Spark 2.0.2 with grouping by multiple columns:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
  • This code is more or less contained in [Apache DataFu's](http://datafu.apache.org/) [dedupWithOrder method](https://github.com/apache/datafu/blob/master/datafu-spark/src/main/scala/datafu/spark/SparkDFUtils.scala#L139) – Eyal Jul 26 '21 at 10:13
11

This is exactly the same as zero323's answer, but written as SQL queries.

Assuming the DataFrame has been created and registered as a temporary view:

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Window function :

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Plain SQL aggregation followed by join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Using ordering over structs:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

The Dataset approach and the "don't use" warnings are the same as in the original answer.

Community
Ramesh Maharjan
4

You can use the max_by() function (Spark 3.0+).

https://spark.apache.org/docs/3.0.0-preview/api/sql/index.html#max_by

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("table")

// Using SQL
val result = spark.sql("select Hour, max_by(Category, TotalValue) AS Category, max(TotalValue) as TotalValue FROM table group by Hour order by Hour")

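// or using the DataFrame API (requires: import org.apache.spark.sql.functions.{expr, max});
// named resultDf so it does not clash with the SQL `result` above
val resultDf = df.groupBy("Hour").
  agg(expr("max_by(Category, TotalValue)").as("Category"), max("TotalValue").as("TotalValue")).
  sort("Hour")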

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
+----+--------+----------+
Heedo Lee
3

You can do this easily with Apache DataFu (the implementation is similar to Antonin's answer).

import datafu.spark.DataFrameOps._

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

df.dedupWithOrder($"Hour", $"TotalValue".desc).show

which will result in

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   3|    cat8|      35.6|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
+----+--------+----------+

(yes, the result will not be ordered by Hour, but you can always do that later if it's important)

There's also an API - dedupTopN - for taking the top N rows. And another API - dedupWithCombiner - when you expect a large number of rows per grouping.

(full disclosure - I'm part of the DataFu project)

Eyal
2

The solution below does only one groupBy and extracts the rows of your DataFrame that contain the maximum value in one shot. There is no need for further joins or windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

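// df is the DataFrame with Day, Category, TotalValue
// (column 0 is the Int grouping key, column 2 is the Double value)

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df
    .groupByKey{(r) => r.getInt(0)}
    // keep, for each key, the row with the largest value in column 2
    .mapGroups[Row]{(day: Int, rows: Iterator[Row]) => rows.maxBy{(r) => r.getDouble(2)}}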
Geoff Langenderfer
elghoto
  • But it shuffles everything first. It is hardly an improvement (maybe not worse than window functions, depending on the data). – Alper t. Turker Jun 06 '18 at 22:00
  • You have a groupBy in the first place, which triggers a shuffle. It's no worse than a window function, because a window function evaluates the window for every single row in the DataFrame. – elghoto Jun 06 '18 at 22:09
2

A nice way of doing this with the dataframe api is using the argmax logic like so

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+
randal25
2

The pattern is group by keys => do something to each group e.g. reduce => return to dataframe

I thought the Dataframe abstraction is a bit cumbersome in this case so I used RDD functionality

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
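As a variation on the snippet above, `reduceByKey` can be used in place of `groupBy` followed by `reduce`, which lets Spark combine rows on the map side before shuffling. This is only a sketch under assumptions: the `keepLarger` function is a hypothetical stand-in for `reduceFunction`, the grouping column is taken to be `Hour`, and a local SparkSession replaces `sqlContext`.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("rdd-reduce-by-key").getOrCreate()
import spark.implicits._

val originalDf = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1),
  (1, "cat67", 28.5), (1, "cat4", 26.8)
).toDF("Hour", "Category", "TotalValue")

// Hypothetical reduce function: keep the row with the larger TotalValue.
val keepLarger: (Row, Row) => Row = (a, b) =>
  if (a.getAs[Double]("TotalValue") >= b.getAs[Double]("TotalValue")) a else b

val rdd: RDD[Row] = originalDf.rdd
  .keyBy(row => row.getAs[Int]("Hour")) // pair each row with its grouping key
  .reduceByKey(keepLarger)              // reduce per key, combining map-side first
  .values                               // drop the key, keep the winning row

// Rebuild a DataFrame using the original schema.
val productDf = spark.createDataFrame(rdd, originalDf.schema)

productDf.orderBy($"Hour").show()
```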
Rubber Duck
0

You can do it like this:

val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show
Koedlt
0

While the window function proposed by @zero323 works, I found that computing latest_per_key = data.groupby(*keys).agg(F.max(F.col('timestamp')).alias('timestamp')) to find the latest sample and then running data.join(latest_per_key, how='left_semi', on=keys + ['timestamp']) is faster.
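For consistency with the other answers, here is a rough Scala sketch of the same aggregate-then-semi-join idea, applied to the question's `Hour` / `TotalValue` columns instead of `keys` / `timestamp`. The column mapping and the local SparkSession are assumptions, and no performance claim is implied. Like the inner-join variant, it keeps every row tied for the maximum.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("semi-join-top").getOrCreate()
import spark.implicits._

val df = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1),
  (1, "cat67", 28.5), (1, "cat4", 26.8)
).toDF("Hour", "Category", "TotalValue")

// Per-group maximum, named like the original column so it can serve as a join key.
val maxPerHour = df.groupBy($"Hour").agg(max($"TotalValue").as("TotalValue"))

// A left-semi join keeps only the rows of df that have a match, and only df's columns.
val dfTopBySemiJoin = df.join(maxPerHour, Seq("Hour", "TotalValue"), "left_semi")

dfTopBySemiJoin.orderBy($"Hour").show()
```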

Hanan Shteingart