I am trying to calculate a weighted mean in PySpark but am not making much progress.
# Example data
df = sc.parallelize([
    ("a", 7, 1), ("a", 5, 2), ("a", 4, 3),
    ("b", 2, 2), ("b", 5, 4), ("c", 1, -1)
]).toDF(["k", "v1", "v2"])
df.show()
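For reference, df.show() prints:

+---+---+---+
|  k| v1| v2|
+---+---+---+
|  a|  7|  1|
|  a|  5|  2|
|  a|  4|  3|
|  b|  2|  2|
|  b|  5|  4|
|  c|  1| -1|
+---+---+---+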
I then define the weighted mean as a UDF:

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

def weighted_mean(workclass, final_weight):
    return float(np.average(workclass, weights=final_weight))

# np.average returns a float, so the return type should be
# DoubleType rather than IntegerType
weighted_mean_udaf = udf(weighted_mean, DoubleType())
But when I try to execute this code:

df.groupby('k').agg(weighted_mean_udaf(df.v1, df.v2)).show()
I get the error:
u"expression 'pythonUDF' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get
My question is: can I pass a custom function (taking multiple arguments) as an argument to agg? If not, is there an alternative way to perform operations like a weighted mean after grouping by a key?
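For what it's worth, I can get the numbers I want for this specific case in two roundabout ways (sketches, so I may be missing something). The first expresses the weighted mean entirely with built-in aggregates, since it is just sum(v1*v2) / sum(v2):

from pyspark.sql import functions as F

# weighted mean via built-in aggregates: sum(v1 * v2) / sum(v2)
df.groupby('k').agg(
    (F.sum(df.v1 * df.v2) / F.sum(df.v2)).alias('weighted_mean')
).show()

The second collects each group's values and weights into arrays with collect_list (assuming it is available in your Spark build) and then applies the plain UDF to the arrays; since both lists are built from the same rows in the same aggregation pass, the value/weight pairs should stay aligned:

# collect values and weights per group, then apply the plain UDF
grouped = df.groupby('k').agg(
    F.collect_list('v1').alias('v1s'),
    F.collect_list('v2').alias('v2s'),
)
grouped.select('k', weighted_mean_udaf('v1s', 'v2s').alias('weighted_mean')).show()

Both work for this example, but neither feels like a general mechanism for plugging an arbitrary multi-argument function into agg, which is what I am really after.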