How to implement a User Defined Aggregate Function (UDAF) in PySpark SQL?
PySpark version = 3.0.2
Python version = 3.7.10
As a minimal example, I'd like to replace the AVG aggregate function with a UDAF:
import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext()
sql = SQLContext(sc)

df = sql.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')

rv = sql.sql('SELECT id, AVG(value) FROM df GROUP BY id').toPandas()
where rv will be:
In [2]: rv
Out[2]:
   id  avg(value)
0   1         1.5
1   2         3.5
How can a UDAF replace AVG in the query?
For example, this does not work, because udf_avg is registered as a plain (scalar) UDF rather than an aggregate function:
import numpy as np

def udf_avg(x):
    return np.mean(x)

sql.udf.register('udf_avg', udf_avg)
rv = sql.sql('SELECT id, udf_avg(value) FROM df GROUP BY id').toPandas()
The idea is to implement a UDAF in pure Python for processing that is not supported by the built-in SQL aggregate functions (e.g. a low-pass filter).
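
For reference, here is a minimal sketch of the kind of thing I have in mind, using a grouped-aggregate (Series-to-scalar) pandas_udf. I am not sure whether this is the intended mechanism for a UDAF in SQL queries in 3.0.2, or whether a purely Python (non-pandas) UDAF is possible at all; the function name pandas_avg and the mean aggregation are just placeholders for the real processing:

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Series-to-scalar pandas UDF: receives all 'value's of one group as a
# pd.Series and returns a single scalar per group (placeholder aggregation).
@pandas_udf('double')
def pandas_avg(v: pd.Series) -> float:
    return v.mean()

# Reuses the sql context and the 'df' temp view defined above.
sql.udf.register('pandas_avg', pandas_avg)
rv = sql.sql('SELECT id, pandas_avg(value) FROM df GROUP BY id').toPandas()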