There is no built-in multiplicative aggregation. Your workaround seems efficient to me; the other solutions require building custom aggregation functions.
import pyspark.sql.functions as F
ex = spark.createDataFrame([[1,2],[4,5], [6,7], [3,2], [9,8], [4,2]],['a','b'])
ex.show()
+---+---+
| a| b|
+---+---+
| 1| 2|
| 4| 5|
| 6| 7|
| 3| 2|
| 9| 8|
| 4| 2|
+---+---+
# Solution 1
ex.agg(F.exp(F.sum(F.log('a')))).show()
+----------------+
|EXP(sum(LOG(a)))|
+----------------+
| 2592.0|
+----------------+
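Note that Solution 1 relies on the identity prod(a) = exp(sum(log(a))), which only holds for strictly positive values: a zero makes log return null and the result is no longer the true product. A minimal sketch of one way to guard against zeros (the zero-handling logic is my own addition, not part of the original workaround, and it assumes the values are non-negative; negative inputs would also need sign tracking):
# If the column contains any zero the product is zero,
# otherwise fall back to the exp-sum-log trick.
ex.agg(
    F.when(F.min('a') == 0, F.lit(0.0))
     .otherwise(F.exp(F.sum(F.log('a'))))
     .alias('product_a')
).show()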
# Solution 2
from functools import reduce  # reduce is not a builtin in Python 3
from pyspark.sql.types import IntegerType

def mul_list(l):
    return reduce(lambda x, y: x * y, l)

udf_mul_list = F.udf(mul_list, IntegerType())
ex.agg(udf_mul_list(F.collect_list('a'))).show()
+-------------------------------+
|mul_list(collect_list(a, 0, 0))|
+-------------------------------+
| 2592|
+-------------------------------+
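One caveat of Solution 2 (my own note, not from the original answer): with an IntegerType return type the result can silently become null once the product exceeds the 32-bit range. A minimal variant with a 64-bit return type to push that limit back:
from functools import reduce
from pyspark.sql.types import LongType

# Same UDF as above, but declared as LongType to reduce the risk of overflow.
udf_mul_list_long = F.udf(lambda l: reduce(lambda x, y: x * y, l), LongType())
ex.agg(udf_mul_list_long(F.collect_list('a'))).show()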
# Solution 3
# RDD aggregate: 1 is the identity element for multiplication,
# seqOp multiplies the running result by each row's 'a' within a partition,
# combOp multiplies the partial results coming from different partitions.
seqOp = (lambda local_result, row: local_result * row['a'])
combOp = (lambda local_result1, local_result2: local_result1 * local_result2)
ex_rdd = ex.rdd
ex_rdd.aggregate(1, seqOp, combOp)
Out[4]: 2592
Now let's compare performance:
import random
ex = spark.createDataFrame([[random.randint(1, 10), 3] for i in range(10000)],['a','b'])
%%timeit
ex.agg(F.exp(F.sum(F.log('a')))).count()
10 loops, best of 3: 84.9 ms per loop
%%timeit
ex.agg(udf_mul_list(F.collect_list('a'))).count()
10 loops, best of 3: 78.8 ms per loop
%%timeit
ex_rdd = ex.rdd
ex_rdd.aggregate( 1, seqOp, combOp)
10 loops, best of 3: 94.3 ms per loop
Performance seems about the same on a single partition running locally; try it on a larger DataFrame spread across several partitions.
To improve on solutions 2 and 3, build a custom aggregation function in Scala and wrap it in Python.
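For reference, the usual way to call such a Scala aggregation from PySpark is through the JVM gateway. This is only a sketch: com.example.ProductAgg is a hypothetical compiled Scala UDAF that must already be on the classpath (e.g. via --jars), and the helpers come from PySpark's internal pyspark.sql.column module.
from pyspark.sql.column import Column, _to_java_column, _to_seq

def product_agg(col):
    # Instantiate the (hypothetical) Scala UDAF via py4j and apply it to the column.
    sc = spark.sparkContext
    jagg = sc._jvm.com.example.ProductAgg()
    return Column(jagg.apply(_to_seq(sc, [col], _to_java_column)))

# Usage: ex.agg(product_agg(F.col('a'))).show()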