Define a custom aggregation function as follows:
/**
 * User-defined aggregate function that multiplies together all the `Long`
 * values of a group.
 *
 * Null inputs are skipped rather than folded into the product, so a single
 * SQL NULL no longer collapses the whole group to 0. A group that contains no
 * non-null value yields 1, the empty product.
 *
 * The multiplication is plain `Long` arithmetic: products that exceed the
 * `Long` range wrap around silently.
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in
 * favour of `Aggregator` registered through `functions.udaf`; this class keeps
 * the original API so existing callers continue to work.
 */
class Product extends UserDefinedAggregateFunction {

  /** Input fields of the aggregate function: a single `Long` column. */
  override def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  /** Internal fields kept while aggregating: the running product. */
  override def bufferSchema: StructType =
    StructType(StructField("product", LongType) :: Nil)

  /** Output type of the aggregation function. */
  override def dataType: DataType = LongType

  /** The same input always produces the same result. */
  override def deterministic: Boolean = true

  /** Starts every buffer at 1, the multiplicative identity. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1L
  }

  /**
   * Folds one input row into the running product.
   *
   * Null inputs are ignored: reading a null cell as a primitive `Long` would
   * otherwise yield 0 and zero out the product for the whole group.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) * input.getLong(0)
    }
  }

  /**
   * Combines two partial products. Buffers are initialised to 1 and only ever
   * overwritten with `Long` values, so no null check is needed here.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) * buffer2.getLong(0)
  }

  /** Returns the final product held in the buffer. */
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}
Then use it in an aggregation as follows:
// Instance of the custom aggregate function defined above.
val product = new Product

// Sample rows: two grouping keys ("x" and "t") with small integer values.
val rows = Seq(
  ("x", "y", 1),
  ("x", "z", 2),
  ("x", "a", 4),
  ("x", "a", 5),
  ("t", "y", 1),
  ("t", "y2", 6),
  ("t", "y3", 3),
  ("t", "y4", 5)
)
val df = rows.toDF("F1", "F2", "F3")

// Group by F1, apply the aggregate to F3 within each group, and print the result.
df
  .groupBy("F1")
  .agg(product(col("F3")))
  .show()
Here's the output:
+---+-----------+
| F1|product(F3)|
+---+-----------+
| x| 40|
| t| 90|
+---+-----------+