2

I have a column in a data frame. I need to aggregate this column by multiplying its values instead of summing them up.

import pyspark.sql.functions as f

ex = spark.createDataFrame([[1, 2], [4, 5]], ['a', 'b'])
ex.show()
ex.agg(f.sum('a')).show()

Instead of summing, I want to multiply the values in column 'a', with syntax something like:

ex.agg(f.mul('a')).show()

The workaround I thought of is:

ex.agg(f.exp(f.sum(f.log('a')))).show()

However, calculating exp(sum(log)) might not be efficient enough.

The result should be 4. What is the most efficient way?

idkman
Eladwass
  • What version of spark? Spark 2.4+ supports [`aggregate`](https://spark.apache.org/docs/latest/api/sql/index.html#aggregate). – pault Jul 29 '19 at 14:46

2 Answers

3

There is no built-in multiplicative aggregation. Your workaround seems efficient to me; other solutions require building custom aggregation functions.

import pyspark.sql.functions as F
ex = spark.createDataFrame([[1, 2], [4, 5], [6, 7], [3, 2], [9, 8], [4, 2]], ['a', 'b'])
ex.show()

+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  4|  5|
|  6|  7|
|  3|  2|
|  9|  8|
|  4|  2|
+---+---+

# Solution 1
ex.agg(F.exp(F.sum(F.log('a')))).show()

+----------------+
|EXP(sum(LOG(a)))|
+----------------+
|          2592.0|
+----------------+

# Solution 2
from functools import reduce  # reduce is a builtin in Python 2; in Python 3 it lives in functools
from pyspark.sql.types import IntegerType

def mul_list(l):
    return reduce(lambda x, y: x * y, l)

udf_mul_list = F.udf(mul_list, IntegerType())
ex.agg(udf_mul_list(F.collect_list('a'))).show()

+-------------------------------+
|mul_list(collect_list(a, 0, 0))|
+-------------------------------+
|                           2592|
+-------------------------------+

# Solution 3
seqOp = (lambda local_result, row: local_result * row['a'])                    # fold the rows of one partition
combOp = (lambda local_result1, local_result2: local_result1 * local_result2)  # combine the partition results
ex_rdd = ex.rdd
ex_rdd.aggregate(1, seqOp, combOp)

Out[4]: 2592

Now let's compare performance:

import random
ex = spark.createDataFrame([[random.randint(1, 10), 3] for i in range(10000)],['a','b'])

%%timeit
ex.agg(F.exp(F.sum(F.log('a')))).count()

10 loops, best of 3: 84.9 ms per loop

%%timeit
ex.agg(udf_mul_list(F.collect_list('a'))).count()

10 loops, best of 3: 78.8 ms per loop

%%timeit
ex_rdd = ex.rdd
ex_rdd.aggregate(1, seqOp, combOp)

10 loops, best of 3: 94.3 ms per loop

Performance seems about the same on a single partition running locally. Try it on a bigger dataframe split across several partitions.
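As a rough sketch of that suggestion (the row count, random seed and partition count below are arbitrary assumptions, not taken from the answer), you could generate a larger DataFrame directly on the executors and spread it over several partitions before re-running the same comparison:

# Hypothetical benchmark setup: a larger DataFrame split over several partitions.
# spark.range generates rows on the executors, avoiding a big local Python list.
big = (spark.range(1_000_000)
            .withColumn('a', (F.rand(seed=42) * 9 + 1).cast('int'))  # random ints in 1..9
            .withColumn('b', F.lit(3))
            .repartition(8))

%%timeit
big.agg(F.exp(F.sum(F.log('a')))).count()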

To improve the performance of solutions 2 and 3, build a custom aggregation function in Scala and wrap it in Python.
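If you would rather stay in Python, a pandas grouped-aggregate UDF (available since Spark 2.4, and requiring pandas and pyarrow) is another option worth benchmarking; this is only a sketch, and the `product_udf` name is just for illustration. Like collect_list, it still materializes each group in memory, but the Arrow-based serialization is usually cheaper than a plain Python UDF:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def product_udf(a):
    # `a` is a pandas Series holding all values of the group
    return float(a.prod())

ex.agg(product_udf('a')).show()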

Pierre Gourseaud
  • Thank you! Do you know why there's no built-in function for multiplication? Though solution 2 is more elegant, I think it requires more memory. I will stick with solution 1. Thanks! – Eladwass Jul 29 '19 at 10:57
  • 1
    Solution #1 will not work if there is a 0 or negative number in column 'a'. – Xiaojie Zhou Aug 08 '22 at 20:18
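Following up on that last comment: below is a minimal sketch (not part of the original answer) that keeps the log-based aggregation from solution 1 but handles zeros and negative values by counting them separately; the names zeros, negatives and log_abs_sum are just illustrative aliases. On the question's two-row example it returns 4.0, and it returns 0.0 as soon as any value is zero.

import pyspark.sql.functions as F

agg_parts = ex.agg(
    F.sum(F.when(F.col('a') == 0, 1).otherwise(0)).alias('zeros'),
    F.sum(F.when(F.col('a') < 0, 1).otherwise(0)).alias('negatives'),
    F.sum(F.log(F.abs(F.col('a')))).alias('log_abs_sum'),  # log(0) is NULL in Spark and is ignored by sum
)
agg_parts.selectExpr(
    "CASE WHEN zeros > 0 THEN 0.0 "
    "     ELSE POW(-1, negatives % 2) * EXP(log_abs_sum) END AS product_a"
).show()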
1

When I see limitations in the Python Spark API, I always take a look at higher-order functions, as they give you access to functionality that might not yet be integrated into PySpark. They also generally give much better performance than UDFs, because they use optimized native Spark operations. You can read more about higher-order functions here: https://medium.com/@danniesim/faster-and-more-concise-than-udf-spark-functions-and-higher-order-functions-with-pyspark-31d31de5fed8.

For your problem you could use f.aggregate; you can find some examples in the Spark documentation: https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.aggregate.html#pyspark.sql.functions.aggregate. Here, for reference, is how to aggregate values by multiplying:

# f.aggregate folds over an array column, so collect the values into an array first
ex.agg(f.aggregate(f.collect_list('a'), f.lit(1.0), lambda acc, x: acc * x)).show()

EDIT: f.aggregate is available from PySpark 3.1.0. In case you have a previous version, you can do the following (again using a higher-order function, this time the aggregate from the Spark SQL API: https://spark.apache.org/docs/latest/api/sql/#aggregate):

(ex
    .agg(f.collect_list('a').alias('a'))
    .withColumn('a', f.expr("aggregate(a, CAST(1.0 AS DOUBLE), (acc, x) -> acc * x, acc -> acc)")))

This way you only use the native Spark API but, needless to say, it really looks way too complicated for just multiplying values over a group.
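For completeness, here is the same trick applied per group; grouping by 'b' is just for illustration and is not part of the original question:

(ex
    .groupBy('b')
    .agg(f.collect_list('a').alias('a'))
    .withColumn('a', f.expr("aggregate(a, CAST(1.0 AS DOUBLE), (acc, x) -> acc * x)"))
    .show())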