For a UDF-free solution, you can compute the row-wise maximum of the sales columns with greatest, then chain when expressions to find which column holds that value.
from typing import List

from pyspark.sql import Column
from pyspark.sql import functions as F

data = [("Wii Sports", 10, 5, 8, 2),
        ("Mario Kart", 5, 3, 8, 1)]
df = spark.createDataFrame(data, ("Game", "NA_Sales", "EU sales", "JP Sales", "Other Sales"))

# Columns to compare.
metric_cols = ["NA_Sales", "EU sales", "JP Sales", "Other Sales"]

def find_region_max_sales(cols: List[str]) -> Column:
    # Row-wise maximum across the given columns.
    max_sales = F.greatest(*[F.col(c) for c in cols])
    # Chain when() clauses: the first column equal to the maximum
    # yields its name.
    max_col_expr = F
    for c in cols:
        max_col_expr = max_col_expr.when(F.col(c) == max_sales, c)
    return max_col_expr

df.withColumn("region_with_maximum_sales", find_region_max_sales(metric_cols)).show()
"""
+----------+--------+--------+--------+-----------+-------------------------+
| Game|NA_Sales|EU sales|JP Sales|Other Sales|region_with_maximum_sales|
+----------+--------+--------+--------+-----------+-------------------------+
|Wii Sports| 10| 5| 8| 2| NA_Sales|
|Mario Kart| 5| 3| 8| 1| JP Sales|
+----------+--------+--------+--------+-----------+-------------------------+
"""