
I am almost certain this has been asked before, but a search through stackoverflow did not answer my question. Not a duplicate of [2] since I want the maximum value, not the most frequent item. I am new to pyspark and trying to do something really simple: I want to groupBy column "A" and then only keep the row of each group that has the maximum value in column "B". Like this:

df_cleaned = df.groupBy("A").agg(F.max("B"))

Unfortunately, this throws away all other columns - df_cleaned only contains the columns "A" and the max value of B. How do I instead keep the rows? ("A", "B", "C"...)
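
For concreteness, here is a minimal sketch of the behaviour described above, with hypothetical toy data (an extra column "C" and a SparkSession named `spark` assumed):

import pyspark.sql.functions as F

# hypothetical toy data with a third column "C"
df = spark.createDataFrame(
    [('a', 5, 'x'), ('a', 8, 'y'), ('b', 3, 'z')],
    ["A", "B", "C"]
)
df_cleaned = df.groupBy("A").agg(F.max("B"))
df_cleaned.show()
#+---+------+
#|  A|max(B)|
#+---+------+
#|  a|     8|
#|  b|     3|
#+---+------+
# column "C" (and any other non-grouped column) is gone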


4 Answers


You can do this without a udf using a Window.

Consider the following example:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Create a Window to partition by column A and use it to compute the maximum of each group. Then keep only the rows where the value in column B equals that maximum.

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

Or equivalently using pyspark-sql:

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+
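
Side note: on Spark 2.0+, `registerTempTable` is deprecated in favour of `createOrReplaceTempView`, and the `SQLContext` is typically replaced by a `SparkSession` (assumed here to be named `spark`); the equivalent calls would be:

df.createOrReplaceTempView('table')
spark.sql(q).show()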
pault
  • I can't reproduce this solution (Spark 2.4). I get: `java.lang.UnsupportedOperationException: Cannot evaluate expression: max(input[1, bigint, false]) windowspecdefinition(input[0, string, true], specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$()))` – AltShift Jan 17 '19 at 22:27
  • @AltShift I suggest you create a [reproducible example](https://stackoverflow.com/questions/48427185/how-to-make-good-reproducible-apache-spark-examples) and write a new question. Be sure to mention you tried this solution and include the full traceback. (Are you sure you're using `pyspark.sql.functions.max` and not `__builtin__.max`?) – pault Jan 17 '19 at 22:32
  • Thanks @pault, I've just received advice from Databricks that it is an issue with Spark 2.4. When they come back to me with their final analysis I should create a question and answer for the community. – AltShift Jan 23 '19 at 04:20
  • @AltShift; as someone that has hit the same bug, would it make sense to create the question already anyway so the rest of us have a place we can monitor for progress on this issue? – Jeroen Feb 01 '19 at 10:14
  • @Jeroen: Now documented: https://stackoverflow.com/questions/54508608/using-where-on-f-max-overwindow-on-spark-2-4-throws-java-exception/54508646#54508646 – AltShift Feb 04 '19 at 21:52
  • Might want to use dense_rank() instead of max() if you want to ensure uniqueness by group. In my use case, I'm left joining the results of this back to the original df to impute missings. The results need to be distinct. – justin cress Jun 10 '19 at 15:15
  • Why is `withColumn("maxB")....drop("maxB")` necessary? I checked, and `df.where(f.col("B") == f.max("B").over(w)).show()` fails with an UnsupportedOperationException java error, but I don't understand pyspark's internals well enough to know why – Zane Dufour Jun 13 '19 at 18:29
  • @ZaneDufour not spark-sql, but I believe [this answers your question](https://stackoverflow.com/questions/42470849/why-are-aggregate-functions-not-allowed-in-where-clause). – pault Jun 13 '19 at 19:02
  • @justincress that dense_rank() change does not ensure uniqueness, does it? If there are two rows with the maximum value for B, both would have dense_rank=1, right? – AugSB Jan 20 '22 at 11:38
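
As the last few comments point out, the max-over-window filter keeps every row that ties for the group maximum (and dense_rank() would too). If exactly one row per group is required, a sketch using row_number() instead, ordering the window by B descending (ties are then broken arbitrarily unless more ordering columns are added):

from pyspark.sql import Window
import pyspark.sql.functions as f

# order the window by B descending so the max row gets row number 1
w = Window.partitionBy('A').orderBy(f.col('B').desc())
df.withColumn('rn', f.row_number().over(w))\
    .where(f.col('rn') == 1)\
    .drop('rn')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+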

Another possible approach is to join the dataframe with itself, specifying "leftsemi". This kind of join keeps all columns from the dataframe on the left side and none from the right side.

For example:

import pyspark.sql.functions as f
data = [
    ('a', 5, 'c'),
    ('a', 8, 'd'),
    ('a', 7, 'e'),
    ('b', 1, 'f'),
    ('b', 3, 'g')
]
df = sqlContext.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

The max value of column B per value of column A can be selected with:

df.groupBy('A').agg(f.max('B')).show()
+---+------+
|  A|max(B)|
+---+------+
|  a|     8|
|  b|     3|
+---+------+

Using this expression as the right side of a left semi join on both A and B (joining on B alone could match a row against another group's maximum), and renaming the aggregated column max(B) back to its original name B, we obtain the result we need:

df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B'], how='leftsemi').show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  8|  d|
|  b|  3|  g|
+---+---+---+

The physical plans behind this solution and the accepted answer are different, and it is still not clear to me which one will perform better on large dataframes.
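
For anyone who wants to compare them, the physical plans can be inspected with `explain()`; a minimal sketch, assuming `f` and the example `df` defined above:

from pyspark.sql import Window
w = Window.partitionBy('A')

# plan of the window-based approach (accepted answer)
df.withColumn('maxB', f.max('B').over(w)).where(f.col('B') == f.col('maxB')).explain()

# plan of the left semi join approach
df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B'], how='leftsemi').explain()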

The same result can be obtained using Spark SQL syntax:

df.registerTempTable('table')
q = '''SELECT *
FROM table a LEFT SEMI JOIN (
    SELECT 
        A,
        max(B) as max_B
    FROM table
    GROUP BY A
    ) t
ON a.A=t.A AND a.B=t.max_B
'''
sqlContext.sql(q).show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+
ndricca
  • The question was about getting the max value, not about keeping just one row, so this works regardless of whether the values in column B are unique. Anyway, if you want to keep only one row for each value of column A, you should go for `df.select("A","B",F.row_number().over(Window.partitionBy("A").orderBy("B", ascending=False)).alias("rn")).filter("rn = 1")` – ndricca May 03 '21 at 14:23
  • `ascending` isn't a method for `orderBy()` – Dan Oct 28 '22 at 18:01

There are two great solutions, so I decided to benchmark them. First let me define a bigger dataframe:

import random

from pyspark.sql import Window
import pyspark.sql.functions as f

N_SAMPLES = 600000
N_PARTITIONS = 1000
MAX_VALUE = 100
data = zip([random.randint(0, N_PARTITIONS-1) for i in range(N_SAMPLES)],
          [random.randint(0, MAX_VALUE) for i in range(N_SAMPLES)],
          list(range(N_SAMPLES))
          )
df = spark.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|118| 91|  0|
|439| 80|  1|
|779| 77|  2|
|444| 14|  3|
...

Benchmarking @pault's solution:

%%timeit
w = Window.partitionBy('A')
df_collect = df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .collect()

gives

655 ms ± 70.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Benchmarking @ndricca's solution:

%%timeit
df_collect = df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').collect()

gives

1 s ± 49.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So, @pault's solution seems to be about 1.5x faster. Feedback on this benchmark is very welcome.
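
One hedged caveat: Spark evaluates lazily, so each timed run also re-computes the randomly generated input dataframe; caching and materializing `df` before timing makes the comparison a bit fairer:

df = df.cache()
df.count()  # force the cache to be populated before running the benchmarks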

Fernando Wittmann

Just want to add a Scala Spark version of @ndricca's answer in case anyone needs it:

// outside spark-shell, toDF on a local Seq needs the SparkSession implicits in scope
import spark.implicits._

val data = Seq(("a", 5, "c"), ("a", 8, "d"), ("a", 7, "e"), ("b", 1, "f"), ("b", 3, "g"))
val df = data.toDF("A","B","C")
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

val rightdf = df.groupBy("A").max("B")
rightdf.show()
+---+------+
|  A|max(B)|
+---+------+
|  b|     3|
|  a|     8|
+---+------+

// join on both A and B so a row only matches its own group's maximum
val resdf = df.join(rightdf, df("A") === rightdf("A") && df("B") === rightdf("max(B)"), "leftsemi")
resdf.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  8|  d|
|  b|  3|  g|
+---+---+---+

user9875189
    Interesting, but perhaps not as relevant to a pyspark question? – de1 Jan 28 '21 at 14:59
  • I myself am searching for a way to achieve this in Scala Spark and landed on this question. I'm sure there is someone like me; I just hope this can save them time. – user9875189 Jan 28 '21 at 15:10
  • I agree, I indeed would be happy if more people posted the alternative language solution (with a clear disclaimer that it is for another language) since Google's search algorithm often brings one to the wrong language, or questions are only answered in PySpark / Scala. – Thomas Sep 24 '21 at 13:42