
I have a dataframe as follows:

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")


+---+---+---+
| F1| F2| F3|
+---+---+---+
|  x|  y|  1|
|  x|  z|  2|
|  x|  a|  4|
|  x|  a|  5|
|  t|  y|  1|
|  t| y2|  6|
|  t| y3|  3|
|  t| y4|  5|
+---+---+---+

How do I groupBy on column "F1" and multiply the values of "F3" within each group?

For sum, I can do the following, but I'm not sure which function to use for multiplication.

df.groupBy("F1").agg(sum("F3")).show

+---+-------+
| F1|sum(F3)|
+---+-------+
|  x|     12|
|  t|     15|
+---+-------+
– user3243499
  • You can create a **custom aggregation** for it. If you want to keep using a `DataFrame`, you need an [untyped user-defined aggregation](https://spark.apache.org/docs/latest/sql-getting-started.html#untyped-user-defined-aggregate-functions); I think the tutorial is clear enough. However, if you still need help after reading it, don't hesitate to edit the question with your attempt and tag me in a comment, and I will answer with the implementation ;) – Luis Miguel Mejía Suárez Nov 30 '18 at 19:47
  • Is there a corresponding solution for pyspark? – Louis Yang Jun 12 '19 at 05:00

2 Answers

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

import org.apache.spark.sql.Row

// Switch to the typed Dataset API: group the rows by F1 and
// reduce each group by multiplying the F3 values pairwise.
val x = df.select($"F1", $"F3")
  .groupByKey(r => r.getString(0))
  .reduceGroups((r1, r2) => Row(r1.getString(0), r1.getInt(1) * r2.getInt(1)))

x.show()

+-----+------------------------------------------+
|value|ReduceAggregator(org.apache.spark.sql.Row)|
+-----+------------------------------------------+
|    x|                                   [x, 40]|
|    t|                                   [t, 90]|
+-----+------------------------------------------+
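If you'd rather have a flat two-column result than the (key, struct) pair that reduceGroups produces, you can map the reduced Row back out. A minimal sketch (the column names are my own choice; it assumes spark.implicits._ is in scope for the tuple encoder):

// Turn each (key, reduced Row) pair into a plain (String, Int) tuple,
// then rename the columns.
val flat = x.map { case (key, row) => (key, row.getInt(1)) }
  .toDF("F1", "product(F3)")

flat.show()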
– Arnon Rotem-Gal-Oz

Define a custom aggregate function as follows:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

class Product extends UserDefinedAggregateFunction {
  // These are the input fields for your aggregate function.
  override def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  // These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType =
    StructType(StructField("product", LongType) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1L
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) * input.getAs[Long](0)
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) * buffer2.getAs[Long](0)
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)
  }
}

Then use it in the aggregation as follows:

import org.apache.spark.sql.functions.col

val product = new Product

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

df.groupBy("F1").agg(product(col("F3"))).show

Here's the output:

+---+-----------+
| F1|product(F3)|
+---+-----------+
|  x|         40|
|  t|         90|
+---+-----------+
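
Note that UserDefinedAggregateFunction is deprecated since Spark 3.0. On Spark 3.x the same product can be written as a typed Aggregator registered with functions.udaf. A minimal sketch under that assumption (the name ProductAgg is my own):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Typed aggregator that multiplies Long values (Spark 3.0+ only).
object ProductAgg extends Aggregator[Long, Long, Long] {
  def zero: Long = 1L                              // neutral element of multiplication
  def reduce(acc: Long, x: Long): Long = acc * x   // fold one value into the running product
  def merge(a: Long, b: Long): Long = a * b        // combine partial products from two partitions
  def finish(acc: Long): Long = acc                // the buffer already holds the result
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val productUdaf = udaf(ProductAgg)
df.groupBy("F1").agg(productUdaf(col("F3").cast("long")).as("product(F3)")).show

This should give the same two rows as the UDAF above.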
– user238607