
Given the following DataFrame:

+----+-----+---+-----+
| uid|    k|  v|count|
+----+-----+---+-----+
|   a|pref1|  b|  168|
|   a|pref3|  h|  168|
|   a|pref3|  t|   63|
|   a|pref3|  k|   84|
|   a|pref1|  e|   84|
|   a|pref2|  z|  105|
+----+-----+---+-----+

How can I get the maximum count for each (uid, k) pair while keeping the matching v?

+----+-----+---+----------+
| uid|    k|  v|max(count)|
+----+-----+---+----------+
|   a|pref1|  b|       168|
|   a|pref3|  h|       168|
|   a|pref2|  z|       105|
+----+-----+---+----------+

I can do something like this, but it drops the column "v":

df.groupBy("uid", "k").max("count")
zero323
jfgosselin

3 Answers


This is a perfect use case for window functions (applied with the over function) or for a join.

Since you've already figured out how to use windows, I'll focus exclusively on the join.

scala> val inventory = Seq(
     |   ("a", "pref1", "b", 168),
     |   ("a", "pref3", "h", 168),
     |   ("a", "pref3", "t",  63)).toDF("uid", "k", "v", "count")
inventory: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 2 more fields]

scala> val maxCount = inventory.groupBy("uid", "k").max("count")
maxCount: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 1 more field]

scala> maxCount.show
+---+-----+----------+
|uid|    k|max(count)|
+---+-----+----------+
|  a|pref3|       168|
|  a|pref1|       168|
+---+-----+----------+

scala> val maxCount = inventory.groupBy("uid", "k").agg(max("count") as "max")
maxCount: org.apache.spark.sql.DataFrame = [uid: string, k: string ... 1 more field]

scala> maxCount.show
+---+-----+---+
|uid|    k|max|
+---+-----+---+
|  a|pref3|168|
|  a|pref1|168|
+---+-----+---+

scala> maxCount.join(inventory, Seq("uid", "k")).where($"max" === $"count").show
+---+-----+---+---+-----+
|uid|    k|max|  v|count|
+---+-----+---+---+-----+
|  a|pref3|168|  h|  168|
|  a|pref1|168|  b|  168|
+---+-----+---+---+-----+
Jacek Laskowski
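The join-back logic in the answer above does not depend on Spark. As a minimal sketch, here it is in plain Python on the question's six rows, with each row modeled as a (uid, k, v, count) tuple. The helper name top_rows_by_join is hypothetical and nothing here is Spark API: step 1 plays the role of groupBy("uid", "k").max("count"), and step 2 plays the role of the join plus the max === count filter.

```python
# Plain-Python model of: groupBy("uid", "k").max("count"), then join back and filter.
# Rows are (uid, k, v, count) tuples; this is NOT Spark API, only the same logic.

ROWS = [
    ("a", "pref1", "b", 168),
    ("a", "pref3", "h", 168),
    ("a", "pref3", "t", 63),
    ("a", "pref3", "k", 84),
    ("a", "pref1", "e", 84),
    ("a", "pref2", "z", 105),
]


def top_rows_by_join(rows):
    # Step 1: aggregate the maximum count per (uid, k) group
    max_count = {}
    for uid, k, _v, count in rows:
        key = (uid, k)
        max_count[key] = max(count, max_count.get(key, count))

    # Step 2: "join" each row with its group's maximum and keep the rows that match it
    return [row for row in rows if row[3] == max_count[(row[0], row[1])]]


for row in top_rows_by_join(ROWS):
    print(row)
```

Because the filter compares every row against its group's maximum, two rows tied for the maximum in the same group are both kept, just as with the Spark join.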

Here's the best solution I came up with so far:

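import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}
// Rank rows within each (uid, k) group by descending count; tied rows share rank 1
val w = Window.partitionBy("uid", "k").orderBy(col("count").desc)

df.withColumn("rank", dense_rank().over(w)).where(col("rank") === 1).select("uid", "k", "v", "count").show()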
jfgosselin
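To see what dense_rank contributes, the following plain-Python sketch models dense_rank().over(w) followed by the rank == 1 filter. It illustrates the semantics only and is not Spark API. The helper top_rows_by_dense_rank is hypothetical, and the extra ("a", "pref2", "y", 105) row in the last call is hypothetical data added to show a tie. It is not part of the question.

```python
from collections import defaultdict

# Plain-Python model of dense_rank() over partitionBy("uid", "k").orderBy(count desc),
# keeping only rank 1. Rows are (uid, k, v, count) tuples; this is NOT Spark API.

ROWS = [
    ("a", "pref1", "b", 168),
    ("a", "pref3", "h", 168),
    ("a", "pref3", "t", 63),
    ("a", "pref3", "k", 84),
    ("a", "pref1", "e", 84),
    ("a", "pref2", "z", 105),
]


def top_rows_by_dense_rank(rows):
    # Partition rows by (uid, k), preserving first-seen partition order
    partitions = defaultdict(list)
    for row in rows:
        partitions[(row[0], row[1])].append(row)

    result = []
    for part in partitions.values():
        # dense_rank: distinct counts sorted descending get ranks 1, 2, 3, ... with no gaps
        distinct_desc = sorted({row[3] for row in part}, reverse=True)
        rank_of = {count: i + 1 for i, count in enumerate(distinct_desc)}
        # Keep every row ranked 1, so rows tied for the maximum all survive
        result.extend(row for row in part if rank_of[row[3]] == 1)
    return result


print(top_rows_by_dense_rank(ROWS))
# Hypothetical tie in group (a, pref2): both rows get rank 1 and are kept
print(top_rows_by_dense_rank(ROWS + [("a", "pref2", "y", 105)]))
```

Swapping dense_rank for row_number would instead keep exactly one row per group, chosen arbitrarily among ties.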

You can use window functions:

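from pyspark.sql.functions import col, max as max_
from pyspark.sql.window import Window

w = Window.partitionBy("uid", "k")

# Attach each (uid, k) group's maximum count to every row, then keep only the rows that reach it
df.withColumn("max_count", max_("count").over(w)).where(col("count") == col("max_count"))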
1d210d2d0