
I am trying to implement a cumulative product in Spark Scala, but I really don't know how to do it. I have the following DataFrame:

Input data:
+--+--+--------+----+
|A |B | date   | val|
+--+--+--------+----+
|rr|gg|20171103| 2  |
|hh|jj|20171103| 3  |
|rr|gg|20171104| 4  |
|hh|jj|20171104| 5  |
|rr|gg|20171105| 6  |
|hh|jj|20171105| 7  |
+--+--+--------+----+

And I would like to have the following output:

Output data:
+--+--+--------+-----+
|A |B | date   | val |
+--+--+--------+-----+
|rr|gg|20171105| 48  | // 2 * 4 * 6
|hh|jj|20171105| 105 | // 3 * 5 * 7
+--+--+--------+-----+
Marc Lamberti
  • I'm currently working on a [pull-request](https://github.com/rwpenney/spark/tree/feature/agg-product) to provide a native Spark solution for this. While it's certainly possible to use `exp(sum(log(...)))`, there are some pitfalls to beware of. There's a [Spark JIRA issue](https://issues.apache.org/jira/browse/SPARK-33678) covering this - I'd welcome others' comments on use-cases for this sort of aggregation function. – rwp Dec 08 '20 at 17:06

4 Answers


As long as the numbers are strictly positive (0 can be handled as well, if present, using coalesce), as in your example, the simplest solution is to compute the sum of logarithms and take the exponential:

import org.apache.spark.sql.functions.{exp, log, max, round, sum}
import spark.implicits._

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3), 
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5), 
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"), 
    exp(sum(log($"val"))).as("val"))

Since this uses floating-point arithmetic, the result won't be exact:

result.show
+---+---+--------+------------------+
|  A|  B|    date|               val|
+---+---+--------+------------------+
| hh| jj|20171105|104.99999999999997|
| rr| gg|20171105|47.999999999999986|
+---+---+--------+------------------+

but after rounding it should be good enough for the majority of applications:

result.withColumn("val", round($"val")).show
+---+---+--------+-----+
|  A|  B|    date|  val|
+---+---+--------+-----+
| hh| jj|20171105|105.0|
| rr| gg|20171105| 48.0|
+---+---+--------+-----+
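
The parenthetical note about zeros matters because log(0) evaluates to null in Spark SQL and sum skips nulls, so a zero would silently disappear from the product instead of forcing it to 0. A minimal sketch of one way to guard against that (using when/min rather than coalesce, still assuming non-negative values; dfWithZero is a made-up example input):

import org.apache.spark.sql.functions.{exp, lit, log, max, min, sum, when}

// hypothetical input where one group contains a zero
val dfWithZero = Seq(
  ("rr", "gg", "20171103", 0), ("rr", "gg", "20171104", 4),
  ("hh", "jj", "20171103", 3), ("hh", "jj", "20171104", 5)
).toDF("A", "B", "date", "val")

val resultWithZero = dfWithZero
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    // if any value in the group is 0 the product is 0; otherwise fall back to exp(sum(log))
    when(min($"val") === 0, lit(0.0))
      .otherwise(exp(sum(log($"val")))).as("val"))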

If that's not enough, you can define a UserDefinedAggregateFunction or Aggregator (see How to define and use a User-Defined Aggregate Function in Spark SQL?) or use the functional API with reduceGroups:

import scala.math.Ordering

case class Record(A: String, B: String, date: String, value: Long)

df.withColumnRenamed("val", "value").as[Record]
  .groupByKey(x => (x.A, x.B))
  .reduceGroups((x, y) => x.copy(
    date = Ordering[String].max(x.date, y.date),
    value = x.value * y.value))
  .toDF("key", "value")
  .select($"value.*")
  .show
+---+---+--------+-----+
|  A|  B|    date|value|
+---+---+--------+-----+
| hh| jj|20171105|  105|
| rr| gg|20171105|   48|
+---+---+--------+-----+
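
For the Aggregator option mentioned above, a minimal sketch could look like the following (ProductAgg is a made-up name; the udaf wrapper requires Spark 3.0+, while on Spark 2.x an Aggregator would be used through the typed Dataset API instead):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{max, udaf}

object ProductAgg extends Aggregator[Long, Long, Long] {
  def zero: Long = 1L                             // neutral element of multiplication
  def reduce(acc: Long, x: Long): Long = acc * x  // fold one value into the buffer
  def merge(a: Long, b: Long): Long = a * b       // combine partial products
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val productUdaf = udaf(ProductAgg)

df.groupBy("A", "B")
  .agg(max($"date").as("date"), productUdaf($"val".cast("long")).as("val"))
  .show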
zero323

You can solve this using either collect_list + a UDF or a UDAF. A UDAF may be more efficient, but it is harder to implement because of the local aggregation.

If you have a DataFrame like this:

+---+---+
|key|val|
+---+---+
|  a|  1|
|  a|  2|
|  a|  3|
|  b|  4|
|  b|  5|
+---+---+
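
To reproduce that input, the DataFrame can be built like this (a minimal sketch using toDF; the column names key and val are taken from the table above):

import spark.implicits._

// sample data matching the table above
val df = Seq(("a", 1), ("a", 2), ("a", 3), ("b", 4), ("b", 5)).toDF("key", "val")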

You can invoke a UDF:

import org.apache.spark.sql.functions.{collect_list, udf}

val prod = udf((vals: Seq[Int]) => vals.reduce(_ * _))

df
  .groupBy($"key")
  .agg(prod(collect_list($"val")).as("val"))
  .show()

+---+---+
|key|val|
+---+---+
|  b| 20|
|  a|  6|
+---+---+
Raphael Roth

Since Spark 2.4, you can also compute this using the higher-order function aggregate:

import org.apache.spark.sql.functions.{expr, max}
import spark.implicits._

val df = Seq(
  ("rr", "gg", "20171103", 2),
  ("hh", "jj", "20171103", 3),
  ("rr", "gg", "20171104", 4),
  ("hh", "jj", "20171104", 5),
  ("rr", "gg", "20171105", 6),
  ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    expr("""
   aggregate(
     collect_list(val),
     cast(1 as bigint),
     (acc, x) -> acc * x)""").alias("val")
  )
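
On Spark 3.0+, the same fold can also be written with the org.apache.spark.sql.functions.aggregate function instead of an expr string; a rough equivalent (assuming the same df as above) would be:

import org.apache.spark.sql.functions.{aggregate, collect_list, lit, max}

val result2 = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    // fold the collected values, starting from 1 as a bigint accumulator
    aggregate(collect_list($"val"), lit(1L), (acc, x) => acc * x).as("val"))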
Oliver W.

Spark 3.2+

product(e: Column): Column
Aggregate function: returns the product of all numerical elements in a group.

Scala

import org.apache.spark.sql.functions.{max, product}
import spark.implicits._
var df = Seq(
    ("rr", "gg", 20171103, 2),
    ("hh", "jj", 20171103, 3),
    ("rr", "gg", 20171104, 4),
    ("hh", "jj", 20171104, 5),
    ("rr", "gg", 20171105, 6),
    ("hh", "jj", 20171105, 7)
).toDF("A", "B", "date", "val")

df = df.groupBy("A", "B").agg(max($"date").as("date"), product($"val").as("val"))
df.show(false)
// +---+---+--------+-----+
// |A  |B  |date    |val  |
// +---+---+--------+-----+
// |hh |jj |20171105|105.0|
// |rr |gg |20171105|48.0 |
// +---+---+--------+-----+

PySpark

from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()
data = [('rr', 'gg', 20171103, 2),
        ('hh', 'jj', 20171103, 3),
        ('rr', 'gg', 20171104, 4),
        ('hh', 'jj', 20171104, 5),
        ('rr', 'gg', 20171105, 6),
        ('hh', 'jj', 20171105, 7)]
df = spark.createDataFrame(data, ['A', 'B', 'date', 'val'])

df = df.groupBy('A', 'B').agg(F.max('date').alias('date'), F.product('val').alias('val'))
df.show()
#+---+---+--------+-----+
#|  A|  B|    date|  val|
#+---+---+--------+-----+
#| hh| jj|20171105|105.0|
#| rr| gg|20171105| 48.0|
#+---+---+--------+-----+
ZygD