I have a dataset of video games that includes each game's sales per region (NA, EU, JP, Other), with one column per region.

| Game       | NA_Sales | EU sales | JP Sales | Other Sales |
|------------|----------|----------|----------|-------------|
| Wii Sports | 10       | 5        | 8        | 2           |
| Mario Kart | 5        | 3        | 8        | 1           |

I want to create a function that iterates over each row and returns the maximum value for each game. When I use the UDF to create a new column, it should return 10 for Wii Sports and 8 for Mario Kart.

Any comment or help is highly appreciated.
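For reference, the per-row logic the question describes can be written as an ordinary Python function. This is the kind of function one would wrap with `pyspark.sql.functions.udf`. It is a minimal sketch, assuming each row is available as a dict of column name to sales figure; `max_sales` is a hypothetical helper name, not part of any library. It returns both the column name and the value, since the comments below clarify that the name is what is actually wanted.

```python
from typing import Dict, Tuple


# Hypothetical helper: the per-row logic a UDF would wrap.
# Assumes a row is given as a mapping of column name -> sales figure.
def max_sales(row: Dict[str, float]) -> Tuple[str, float]:
    """Return (column_name, value) for the largest sales figure in the row.

    Ties resolve to the first column in the mapping's insertion order,
    because max() returns the first maximal item it encounters.
    """
    best_col = max(row, key=lambda c: row[c])
    return best_col, row[best_col]


rows = {
    "Wii Sports": {"NA_Sales": 10, "EU sales": 5, "JP Sales": 8, "Other Sales": 2},
    "Mario Kart": {"NA_Sales": 5, "EU sales": 3, "JP Sales": 8, "Other Sales": 1},
}

for game, sales in rows.items():
    print(game, max_sales(sales))
```

Wrapping such a function as a PySpark UDF is possible, but built-in column functions such as `greatest` usually perform better, because a Python UDF has to ship every row to a Python worker process.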

  • Use [`greatest`](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.functions.greatest.html) function: `df.select(col("Game"), greatest(col("NA_Sales"), col("EU sales"), col("JP Sales"), col("Other Sales")).alias("max"))` – blackbishop Jan 24 '22 at 10:10
  • Does this answer your question? [how to calculate max value in some columns per row in pyspark](https://stackoverflow.com/questions/44833836/how-to-calculate-max-value-in-some-columns-per-row-in-pyspark) – blackbishop Jan 24 '22 at 10:15
  • @WinterSoldier it's a dataframe. – Joaco Bembhy Jan 24 '22 at 10:18
  • @blackbishop Yeah, that did the job, actually. Is there any way of doing it with a function, though? I want to return the name of the column, not the number. Sorry for not being clear before. – Joaco Bembhy Jan 24 '22 at 10:19
  • Then see [this](https://stackoverflow.com/questions/56389696/select-column-name-per-row-for-max-value-in-pyspark) post – blackbishop Jan 24 '22 at 10:23
  • @blackbishop that was exactly what I was looking for. Thanks! – Joaco Bembhy Jan 24 '22 at 12:30

1 Answer

For a UDF-free solution, you can compute the maximum sales across the columns using greatest, and then chain when expressions to find the column that contains this value.

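from typing import List

from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

data = [("Wii Sports", 10, 5, 8, 2),
        ("Mario Kart", 5, 3, 8, 1)]

df = spark.createDataFrame(data, ("Game", "NA_Sales", "EU sales", "JP Sales", "Other Sales"))

# Sales columns to compare within each row
metric_cols = ["NA_Sales", "EU sales", "JP Sales", "Other Sales"]

def find_region_max_sales(cols: List[str]) -> Column:
    # Row-wise maximum across the given columns (greatest skips nulls)
    max_sales = F.greatest(*[F.col(c) for c in cols])
    # CASE WHEN chain: the name of the first column whose value equals the maximum
    max_col_expr = F.when(F.col(cols[0]) == max_sales, F.lit(cols[0]))
    for c in cols[1:]:
        max_col_expr = max_col_expr.when(F.col(c) == max_sales, F.lit(c))
    return max_col_expr

df.withColumn("region_with_maximum_sales", find_region_max_sales(metric_cols)).show()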

"""
+----------+--------+--------+--------+-----------+-------------------------+
|      Game|NA_Sales|EU sales|JP Sales|Other Sales|region_with_maximum_sales|
+----------+--------+--------+--------+-----------+-------------------------+
|Wii Sports|      10|       5|       8|          2|                 NA_Sales|
|Mario Kart|       5|       3|       8|          1|                 JP Sales|
+----------+--------+--------+--------+-----------+-------------------------+
"""
Nithish