6

I'm trying to aggregate a dataframe on multiple columns. I know that everything I need for the aggregation is within the partition; that is, there's no need for a shuffle because all of the data for each aggregation group is local to the partition.

Taking an example, if I have something like

    val sales = sc.parallelize(List(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5))).repartition(2)

    val tdf = sales.map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
      reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))

    println(tdf.toDebugString)

I get a result like

(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
 +-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
    |  MapPartitionsRDD[10] at repartition at Test.scala:57 []
    |  CoalescedRDD[9] at repartition at Test.scala:57 []
    |  ShuffledRDD[8] at repartition at Test.scala:57 []
    +-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
       |  ParallelCollectionRDD[6] at parallelize at Test.scala:51 []

You can see the MapPartitionsRDD, which is good. But then there's the ShuffledRDD, which I want to prevent because I want per-partition summarization, grouped by column values within the partition.

zero323's suggestion is tantalizingly close, but I need the "group by columns" functionality.

Referring to my sample above, I'm looking for the result that would be produced by

select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod

(I don't really need the partition id; that's just to illustrate that I want per-partition results.)
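For reference, here is a DataFrame sketch of that query, assuming a hypothetical salesDf with columns store, prod, amt, units; spark_partition_id() just makes the per-partition grouping explicit, and a plain groupBy like this still plans a shuffle:

    import org.apache.spark.sql.functions._

    // salesDf is assumed to have columns store, prod, amt, units
    val perPartitionAgg = salesDf
      .groupBy(spark_partition_id().as("partition_id"), col("store"), col("prod"))
      .agg(sum("amt"), avg("units"))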

I've looked at lots of examples, but every debug string I've produced includes the shuffle. I really hope to get rid of it. I guess I'm essentially looking for a groupByKeysWithinPartitions function.
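To make that concrete, here's a minimal sketch of the kind of thing I mean, building on the sales RDD above and assuming each partition's distinct (store, prod) keys fit in a small in-memory map (the mutable-map approach is just illustrative, not an existing Spark API):

    // Fold each partition's rows into a local map keyed by (store, prod) and emit
    // one record per key per partition. The lineage gains only MapPartitionsRDDs,
    // no new ShuffledRDD; keys that span partitions appear once per partition,
    // which is exactly the per-partition result I'm after.
    val perPartitionTotals = sales
      .map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }
      .mapPartitions { iter =>
        val acc = scala.collection.mutable.Map.empty[(String, String), (Double, Double, Double, Int)]
        iter.foreach { case (k, v) =>
          acc(k) = acc.get(k) match {
            case Some((s, mn, mx, u)) => (s + v._1, math.min(mn, v._2), math.max(mx, v._3), u + v._4)
            case None                 => v
          }
        }
        acc.iterator
      }

    println(perPartitionTotals.toDebugString)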

1472580
  • Try this one `sales.mapPartitions(rdd => rdd.reduceByKey( same/expression/you/want ))` – mrsrinivas Oct 11 '17 at 16:45
  • That doesn't compile... the type of 'rdd' in your suggestion is Iterator, which has no 'reduceByKey' member function. Maybe I'm missing something? – 1472580 Oct 11 '17 at 17:11
  • It *does* have a map() function, I'm not sure if that could be part of the solution. [link](https://stackoverflow.com/questions/40892080/how-to-use-mappartitions-in-spark-scala) I tried various combinations of that and wasn't able to get it to compile – 1472580 Oct 11 '17 at 17:50

4 Answers

4

The only way to achieve that is to use mapPartitions and write custom code for grouping and computing your values while iterating over the partition. Since you mention the data is already sorted by the grouping keys (store, prod), we can compute your aggregations efficiently in a pipelined fashion:

(1) Define helper classes:

:paste

case class MyRec(store: String, prod: String, amt: Double, units: Int)

case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)

object MyResult {
  def apply(rec: MyRec): MyResult = new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)

  def aggregate(result: MyResult, rec: MyRec) = {
    new MyResult(result.store,
      result.prod,
      result.total_amt + rec.amt,
      math.min(result.min_amt, rec.amt),
      math.max(result.max_amt, rec.amt),
      result.total_units + rec.units
    )
  }
}

(2) Define pipelined aggregator:

:paste

def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {

  var prev: MyResult = null
  var res: Seq[MyResult] = Nil

  for (crt <- iter) yield {
    // reset the output for this element so a finished group is emitted only once
    res = Nil

    if (prev == null) {
      // first record of the partition starts the first group
      prev = MyResult(crt)
    }
    else if (prev.prod != crt.prod || prev.store != crt.store) {
      // grouping key changed: emit the completed group and start a new one
      res = Seq(prev)
      prev = MyResult(crt)
    }
    else {
      // same key: fold the current record into the running aggregate
      prev = MyResult.aggregate(prev, crt)
    }

    if (!iter.hasNext) {
      // end of the partition: also emit the group still in progress
      res = res ++ Seq(prev)
    }

    res
  }
}

(3) Run aggregation:

:paste

val sales = sc.parallelize(
  List(MyRec("West", "Apple", 2.0, 10),
    MyRec("West", "Apple", 3.0, 15),
    MyRec("West", "Orange", 5.0, 15),
    MyRec("South", "Orange", 3.0, 9),
    MyRec("South", "Orange", 6.0, 18),
    MyRec("East", "Milk", 5.0, 5),
    MyRec("West", "Apple", 7.0, 11)), 2).toDS

sales.mapPartitions(iter => Iterator(iter.toList)).show(false)

val result = sales
  .mapPartitions(recIter => pipelinedAggregator(recIter))
  .flatMap(identity)

result.show
result.explain

Output:

    +-------------------------------------------------------------------------------------+
    |value                                                                                |
    +-------------------------------------------------------------------------------------+
    |[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]]                     |
    |[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
    +-------------------------------------------------------------------------------------+

    +-----+------+---------+-------+-------+-----------+
    |store|  prod|total_amt|min_amt|max_amt|total_units|
    +-----+------+---------+-------+-------+-----------+
    | West| Apple|      5.0|    2.0|    3.0|         25|
    | West|Orange|      5.0|    5.0|    5.0|         15|
    |South|Orange|      9.0|    3.0|    6.0|         27|
    | East|  Milk|      5.0|    5.0|    5.0|          5|
    | West| Apple|      7.0|    7.0|    7.0|         11|
    +-----+------+---------+-------+-------+-----------+

    == Physical Plan ==
    *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
    +- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
       +- MapPartitions <function1>, obj#20: scala.collection.Seq
          +- Scan ExternalRDDScan[obj#4]
    sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
    result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]    
Traian
  • Thanks, I'll give that a shot. Memory is important- is there some way to stream it out when I know I'm done with a 'chunk'? For my example, I do independently know that within a partition when the 'store' changes then I'm done with the previous store (that is, the 'store' is in a contiguous chunk of records) – 1472580 Oct 13 '17 at 15:14
  • As you have the data sorted by grouping keys, I updated the code to do a pipelined aggregation so it requires very little memory. – Traian Oct 14 '17 at 20:43
  • That's perfect... I was going nuts trying to figure it out. The Iterator bit is great- nearly all the mapPartitions examples are simple key-value pairs. – 1472580 Oct 15 '17 at 01:32
0

If this is the output you're looking for

+-----+------+--------+----------+
|store|prod  |max(amt)|avg(units)|
+-----+------+--------+----------+
|South|Orange|6.0     |13.5      |
|West |Orange|5.0     |15.0      |
|East |Milk  |5.0     |5.0       |
|West |Apple |3.0     |12.5      |
+-----+------+--------+----------+

the Spark DataFrame API has all the functionality you're asking for, with concise shorthand syntax

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object TestJob2 {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.show(false)
    rawDf.printSchema

    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max(col("amt")),
        avg(col("units"))
        // in case you need to retain more info
        // , collect_list(struct("*")).as("horizontal")
      )

    aggDf.printSchema

    aggDf.show(false)
  }
}

Uncomment the collect_list line if you also need to retain the original rows for each group:

+-----+------+--------+----------+---------------------------------------------------+
|store|prod  |max(amt)|avg(units)|horizontal                                         |
+-----+------+--------+----------+---------------------------------------------------+
|South|Orange|6.0     |13.5      |[[South, Orange, 3.0, 9], [South, Orange, 6.0, 18]]|
|West |Orange|5.0     |15.0      |[[West, Orange, 5.0, 15]]                          |
|East |Milk  |5.0     |5.0       |[[East, Milk, 5.0, 5]]                             |
|West |Apple |3.0     |12.5      |[[West, Apple, 2.0, 10], [West, Apple, 3.0, 15]]   |
+-----+------+--------+----------+---------------------------------------------------+
Rubber Duck
0

The maximum and average aggregations you specify are over multiple rows.

If you want to keep all the original rows, use a Window function, which will partition by the grouping columns.

If you want to reduce the rows in each window partition, you must specify reduction logic or a filter; a minimal example of the filter approach is sketched below.
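For example, a sketch of the filter approach (not part of the original code; it assumes the same rawDf built in the listing further down), keeping only the row with the maximum amt per (store, prod) group via a row_number over the window:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Rank rows within each (store, prod) group by amt descending,
    // then keep only the top-ranked row of every group.
    val topPerGroup = rawDf
      .withColumn("rn", row_number().over(
        Window.partitionBy("store", "prod").orderBy(col("amt").desc)))
      .filter(col("rn") === 1)
      .drop("rn")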

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


object TestJob7 {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")


    rawDf.show(false)
    rawDf.printSchema

    val storeProdWindow = Window
      .partitionBy("store", "prod")

    val aggDf = rawDf
      .withColumn("max(amt)", max("amt").over(storeProdWindow))
      .withColumn("avg(units)", avg("units").over(storeProdWindow))

    aggDf.printSchema

    aggDf.show(false)
  }
}

Here is the result. Take note that it's already grouped (the window function shuffles the data into its partitions):

+-----+------+---+-----+--------+----------+
|store|prod  |amt|units|max(amt)|avg(units)|
+-----+------+---+-----+--------+----------+
|South|Orange|3.0|9    |6.0     |13.5      |
|South|Orange|6.0|18   |6.0     |13.5      |
|West |Orange|5.0|15   |5.0     |15.0      |
|East |Milk  |5.0|5    |5.0     |5.0       |
|West |Apple |2.0|10   |3.0     |12.5      |
|West |Apple |3.0|15   |3.0     |12.5      |
+-----+------+---+-----+--------+----------+
Rubber Duck
0

Aggregate functions reduce the values of rows for specified columns within the group. You can perform multiple different aggregations, producing new columns with values derived from the input rows, in a single pass, exclusively using DataFrame functionality. If you wish to retain other row values, you need to implement reduction logic that specifies which row each value comes from, for instance keeping all values of the row with the maximum amt. To this end you can use a UDAF (user-defined aggregate function) to reduce rows within the group. In the example I also aggregate max amt and average units with standard aggregate functions in the same pass.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object ReduceAggJob {

  def main (args: Array[String]): Unit = {

    val appName = this.getClass.getName.replace("$", "")
    println(s"appName: $appName")

    val sparkSession = SparkSession
      .builder()
      .appName(appName)
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("West",  "Orange", 17.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.printSchema
    rawDf.show(false)
    // Create an instance of the UDAF KeepRowWithMaxAmt.
    val maxAmtUdaf = new KeepRowWithMaxAmt

    // Keep the row with max amt
    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max("amt"),
        avg("units"),
        maxAmtUdaf(
          col("store"),
          col("prod"),
          col("amt"),
          col("units")
        ).as("KeepRowWithMaxAmt")
      )

    aggDf.printSchema
    aggDf.show(false)
  }
}

The UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
  // These are the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("store", StringType) ::
      StructField("prod", StringType) ::
      StructField("amt", DoubleType) ::
      StructField("units", IntegerType) :: Nil
    )

  // These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )


  // This is the output type of your aggregation function.
  override def dataType: DataType =
    StructType((Array(
      StructField("store", StringType),
      StructField("prod", StringType),
      StructField("amt", DoubleType),
      StructField("units", IntegerType)
    )))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
    buffer(1) = ""
    buffer(2) = 0.0
    buffer(3) = 0
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    val amt = buffer.getAs[Double](2)
    val candidateAmt = input.getAs[Double](2)

    amt match {
      case a if a < candidateAmt =>
        buffer(0) = input.getAs[String](0)
        buffer(1) = input.getAs[String](1)
        buffer(2) = input.getAs[Double](2)
        buffer(3) = input.getAs[Int](3)
      case _ =>
    }
  }

  // This is how to merge two objects with the bufferSchema type.
  // Keep whichever partial buffer currently holds the larger amt; otherwise an
  // empty or smaller partial result would overwrite the running maximum.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    if (buffer2.getAs[Double](2) > buffer1.getAs[Double](2)) {
      buffer1(0) = buffer2.getAs[String](0)
      buffer1(1) = buffer2.getAs[String](1)
      buffer1(2) = buffer2.getAs[Double](2)
      buffer1(3) = buffer2.getAs[Int](3)
    }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer
  }
}
Rubber Duck