
My problem is similar to this one, but instead of udf I need to use pandas_udf.

I have a Spark data frame with many columns (the number of columns varies), and I need to apply a custom function (for example, sum) across them. I know I can hard-code the column names, but that does not work when the number of columns varies.
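
For instance, a hard-coded version might look like the sketch below (the fixed three-Series signature and the name col_sum_fixed are just illustrative); it has to be rewritten whenever the column set changes:

>>> import pandas as pd
>>> import pyspark.sql.functions as F

>>> @F.pandas_udf("double")
... def col_sum_fixed(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
...     # tied to exactly the three columns passed in
...     return a + b + c
... 

>>> df.withColumn('SUM', col_sum_fixed('A', 'B', 'C'))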



1 Answer


The solution is to unpack the columns with the * operator in the function call and to combine the incoming Series with pd.concat inside the pandas_udf function body:

>>> import pandas as pd
>>> import pyspark.sql.functions as F

>>> @F.pandas_udf("double")
... def col_sum(*args: pd.Series) -> pd.Series:
...     pdf = pd.concat(args, axis=1)
...     col_sum = pdf.sum(axis=1)
...     return col_sum
... 

>>> df = spark.createDataFrame([(1,1,1),(2,2,2),(3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+
|  A|  B|  C|SUM|
+---+---+---+---+
|  1|  1|  1|3.0|
|  2|  2|  2|6.0|
|  3|  3|  3|9.0|
+---+---+---+---+
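
The same call keeps working when a fourth column appears; note that since only three names are supplied below, Spark falls back to the default name _4 for the extra column: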

>>> df = spark.createDataFrame([(1,1,1,1),(2,2,2,2),(3,3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+----+
|  A|  B|  C| _4| SUM|
+---+---+---+---+----+
|  1|  1|  1|  1| 4.0|
|  2|  2|  2|  2| 8.0|
|  3|  3|  3|  3|12.0|
+---+---+---+---+----+
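
The same pattern generalizes to any row-wise pandas reduction. As a sketch (col_mean and the dtype filter are my illustration, not part of the original answer), a row-wise mean over only the numeric columns could look like this:

>>> @F.pandas_udf("double")
... def col_mean(*args: pd.Series) -> pd.Series:
...     # concatenate the incoming Series into one pandas DataFrame,
...     # then reduce across columns
...     return pd.concat(args, axis=1).mean(axis=1)
... 

>>> numeric_cols = [c for c, t in df.dtypes if t in ("int", "bigint", "float", "double")]
>>> df = df.withColumn('MEAN', col_mean(*numeric_cols))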