
In Scala/Spark, given a dataframe:

val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")

I would like to compute a new column maxCol holding the name of the column corresponding to the max value (for each row). With this example, the output should be:

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

Actually, the dataframe has more than 60 columns, so a generic solution is required.

The equivalent in Python Pandas (yes, I know, I should compare with pyspark...) could be:

dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1) 
ivankeller

  • How many columns may you have in general? – mrsrinivas Feb 27 '17 at 11:56
  • I have around 60 columns – ivankeller Feb 27 '17 at 12:59
  • How many columns will be in the comparison for the max? – mrsrinivas Feb 27 '17 at 13:08
  • I don't know if it is a duplicate of this [question](http://stackoverflow.com/questions/42030486/scala-spark-in-dataframe-retrieve-for-row-column-name-with-have-max-value/42486873#42486873). Clearly having 60 columns changes the set of viable solutions. Anyway, there is the following [answer](http://stackoverflow.com/a/42486873/3297229) – Wilmerton Feb 27 '17 at 13:31
  • Thanks @Wilmerton! That's a nice solution showing the elegance of Scala/Spark dataframes vs. Python/Pandas dataframes ;) (again, pandas dataframes are not distributed thus the comparison is not really relevant) – ivankeller Feb 27 '17 at 13:38

1 Answer


With a small trick you can use the greatest function. Required imports:

import org.apache.spark.sql.functions.{col, greatest, lit, struct}

First, let's create a list of structs, where the first element is the value and the second is the column name:

val structs = dfIn.columns.tail.map(
  c => struct(col(c).as("v"), lit(c).as("k"))
)

A structure like this can be passed to greatest as follows:

dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k")).show
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c2|
+---+---+---+---+------+

Please note that in case of ties it will take the element that occurs later in the sequence (lexicographically, (x, "c2") > (x, "c1")). If for some reason this is not acceptable, you can explicitly reduce with when:

import org.apache.spark.sql.functions.when

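// Keep the left-hand struct on ties, so the earlier column wins (r2 gets c1 here).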
val max_col = structs.reduce(
  (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")

dfIn.withColumn("maxCol", max_col).show
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

In case of nullable columns you have to adjust this, for example by coalescing the values to -Inf.
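
A minimal sketch of that adjustment (an assumption, not part of the original answer; structsNullSafe is just an illustrative name):

import org.apache.spark.sql.functions.coalesce

// Replace nulls with -Infinity so a null value can never win the comparison.
val structsNullSafe = dfIn.columns.tail.map(
  c => struct(coalesce(col(c), lit(Double.NegativeInfinity)).as("v"), lit(c).as("k"))
)

dfIn.withColumn("maxCol", greatest(structsNullSafe: _*).getItem("k"))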

zero323
  • Is there any way I can modify the code above to also return the 2nd-place max column and the 3rd-place max column? (imagine we have more columns than c0, c1, c2 and we want to return the top 3 columns) – Hana Feb 23 '22 at 16:57
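
A hedged sketch for that follow-up (not part of the original answer; max2Col and max3Col are illustrative names, and it assumes the structs defined above): collect the (value, name) structs into an array, sort it in descending order, and pick the leading entries.

import org.apache.spark.sql.functions.{array, sort_array}

// Structs compare field by field, so sorting by (v, k) descending ranks columns by value.
val ranked = sort_array(array(structs: _*), asc = false)

dfIn
  .withColumn("maxCol",  ranked.getItem(0).getField("k"))
  .withColumn("max2Col", ranked.getItem(1).getField("k"))
  .withColumn("max3Col", ranked.getItem(2).getField("k"))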