How to implement a User Defined Aggregate Function (UDAF) in PySpark SQL?

pyspark version = 3.0.2
python version = 3.7.10

As a minimal example, I'd like to replace the AVG aggregate function with a UDAF:

import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext()
sql = SQLContext(sc)
df = sql.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')
rv = sql.sql('SELECT id, AVG(value) FROM df GROUP BY id').toPandas()

where rv will be:

In [2]: rv
Out[2]:
   id  avg(value)
0   1         1.5
1   2         3.5

How can a UDAF replace AVG in the query?

For example, this does not work:

import numpy as np

def udf_avg(x):
    return np.mean(x)

# registers a plain (row-wise) scalar UDF, not an aggregate,
# so using it with GROUP BY raises an AnalysisException
sql.udf.register('udf_avg', udf_avg)
rv = sql.sql('SELECT id, udf_avg(value) FROM df GROUP BY id').toPandas()

The idea is to implement a UDAF in pure Python for processing not supported by SQL aggregate functions (e.g. a low-pass filter).
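For illustration only, the kind of per-group, pure-Python logic in mind might look like this plain function (a hypothetical sketch, not working Spark code; exponential smoothing stands in for a real low-pass filter):

def low_pass_last(values, alpha=0.5):
    # exponentially smooth the group's values (a simple low-pass filter),
    # then reduce the group to the final smoothed value
    smoothed = float(values[0])
    for v in values[1:]:
        smoothed = alpha * float(v) + (1 - alpha) * smoothed
    return smoothed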

Russell Burdt
  • Does this answer your question? [Applying UDFs on GroupedData in PySpark (with functioning python example)](https://stackoverflow.com/questions/40006395/applying-udfs-on-groupeddata-in-pyspark-with-functioning-python-example) – blackbishop Mar 09 '21 at 07:14
  • No, because the `pandas_udf` definition has changed since Spark 3.0 – Russell Burdt Mar 09 '21 at 18:56

2 Answers

A pandas UDF can be used, with the type-hint style of definition available from Spark 3.0 and Python 3.6+. See SPARK-28264 and the pandas_udf documentation for details.

Full implementation in Spark SQL:

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')

# Series-to-scalar pandas UDF: reduces each group's values to one double
@pandas_udf(DoubleType())
def avg_udf(s: pd.Series) -> float:
    return s.mean()

spark.udf.register('avg_udf', avg_udf)

rv = spark.sql('SELECT id, avg_udf(value) FROM df GROUP BY id').toPandas()

with return value:

In [2]: rv
Out[2]:
   id  avg_udf(value)
0   1             1.5
1   2             3.5
Russell Burdt
  • I think you meant `FloatType` because the signature uses `float`, but otherwise a nice improvement of my answer :) – mck Mar 09 '21 at 19:08
  • If you want to avoid deprecated features, I'd encourage you to use `SparkSession` instead of the (long)-deprecated `SQLContext`. – mck Mar 09 '21 at 19:09
  • `SparkSession` is the better choice, thank you for pointing that out :) Regarding `FloatType` vs `DoubleType`, both work, but I think the latter is the correct implementation because it is double-precision, as is Python's `float`. It does seem unpythonic that we must specify the return type twice and in different formats. Does anyone understand the reason behind that? – Russell Burdt Mar 09 '21 at 19:38
  • I don't know, but you can use a string `'double'`, which saves you an import and some typing... – mck Mar 09 '21 at 19:39
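For reference, the string-typed variant mck suggests would be a one-line change to the answer above (a minimal sketch, assuming the same SparkSession and DataFrame):

import pandas as pd
from pyspark.sql.functions import pandas_udf

# 'double' is a DDL type string; no pyspark.sql.types import needed
@pandas_udf('double')
def avg_udf(s: pd.Series) -> float:
    return s.mean()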

You can use a pandas UDF of GROUPED_AGG type. It receives each group's column from Spark as a pandas Series, so you can call Series.mean on it.

import pyspark.sql.functions as F

# grouped-aggregate pandas UDF: reduces each group's Series to one scalar
@F.pandas_udf('float', F.PandasUDFType.GROUPED_AGG)
def avg_udf(s):
    return s.mean()

df2 = df.groupBy('id').agg(avg_udf('value'))

df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
|  1|           1.5|
|  2|           3.5|
+---+--------------+

Registering it for use in SQL is also possible:

df.createTempView('df')
spark.udf.register('avg_udf', avg_udf)

df2 = spark.sql("select id, avg_udf(value) from df group by id")
df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
|  1|           1.5|
|  2|           3.5|
+---+--------------+
mck
  • The solution you have provided is valid for Spark versions before 3.0, see [this link](https://spark.apache.org/docs/latest/api/python/user_guide/arrow_pandas.html). The pandas UDF definition has changed from Spark 3.0 with Python 3.6+. This is the specific UserWarning that is triggered: `In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for pandas UDF instead of specifying pandas UDF type which will be deprecated in the future releases. See SPARK-28264 for more details.` – Russell Burdt Mar 09 '21 at 18:48