I am using an Aggregator to apply some custom merge on a DataFrame after grouping its records by their primary key:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Player(
  pk: String, 
  ts: String, 
  first_name: String, 
  date_of_birth: String
)

case class PlayerProcessed(
  var ts: String, 
  var first_name: String, 
  var date_of_birth: String
)

// Custom Aggregator - this is just for the example; the actual one is more complex
object BatchDedupe extends Aggregator[Player, PlayerProcessed, PlayerProcessed] {

  def zero: PlayerProcessed = PlayerProcessed("0", null, null)

  def reduce(bf: PlayerProcessed, in : Player): PlayerProcessed = {
    bf.ts = in.ts
    bf.first_name = in.first_name
    bf.date_of_birth = in.date_of_birth
    bf
  }

  def merge(bf1: PlayerProcessed, bf2: PlayerProcessed): PlayerProcessed = {
    bf1.ts = bf2.ts
    bf1.first_name = bf2.first_name
    bf1.date_of_birth = bf2.date_of_birth
    bf1
  }

  def finish(reduction: PlayerProcessed): PlayerProcessed = reduction
  def bufferEncoder: Encoder[PlayerProcessed] = Encoders.product
  def outputEncoder: Encoder[PlayerProcessed] = Encoders.product
}


val ply1 = Player("12121212121212", "10000001", "Rogger", "1980-01-02")
val ply2 = Player("12121212121212", "10000002", "Rogg", null)
val ply3 = Player("12121212121212", "10000004", null, "1985-01-02")
val ply4 = Player("12121212121212", "10000003", "Roggelio", "1982-01-02")

val seq_users = sc.parallelize(Seq(ply1, ply2, ply3, ply4)).toDF.as[Player]

val grouped = seq_users.groupByKey(_.pk)

val non_sorted = grouped.agg(BatchDedupe.toColumn.name("deduped"))
non_sorted.show(false)

This returns:

+--------------+--------------------------------+
|key           |deduped                         |
+--------------+--------------------------------+
|12121212121212|{10000003, Roggelio, 1982-01-02}|
+--------------+--------------------------------+

Now, I would like to order the records based on ts before aggregating them. From here I understand that .sortBy("ts") does not guarantee the order after the .groupByKey(_.pk), so I was trying to apply the sorting between the .groupByKey and the .agg.

The output of the .groupByKey(_.pk) is a KeyValueGroupedDataset[String,Player], and inside .mapGroups each group is exposed as an Iterator. So to apply some sorting logic there I convert it into a Seq:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts))}.agg(BatchDedupe.toColumn.name("deduped"))
sorted.show(false)

However, after adding the sorting logic, the output of .mapGroups is a Dataset[(String, Seq[Player])], so when I try to invoke .agg on it I get the following exception:

Caused by: ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line050e0d37885948cd91f7f7dd9e3b4da9311.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Player

How could I convert back the output of my .mapGroups(...) into a KeyValueGroupedDataset[String,Player]?

I tried to cast back to Iterator as follows:

val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts).toIterator)}.agg(BatchDedupe.toColumn.name("deduped"))

But this approach produced the following exception:

UnsupportedOperationException: No Encoder found for Iterator[Player]
- field (class: "scala.collection.Iterator", name: "_2")
- root class: "scala.Tuple2"

How else can I add the sort logic between the .groupByKey and .agg methods?
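
For completeness, one workaround I can think of (just a rough sketch, reusing the Aggregator's zero and reduce, and not really what I am after since it drops the .agg call entirely) is to fold the sorted group directly inside .mapGroups:

// Rough sketch only: sort each group by ts, then fold it with the existing
// zero/reduce logic instead of going through .agg
val sortedDeduped = grouped.mapGroups { case (k, iter) =>
  val merged = iter.toSeq
    .sortBy(_.ts)
    .foldLeft(BatchDedupe.zero)(BatchDedupe.reduce)
  (k, merged)
}

sortedDeduped.show(false)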

Ignacio Alorre
  • The great benefit of Spark aggregators is that data is reduced (in parallel) on the mapper side before sending all individual rows over the wire. What you are trying to achieve is not going to work with aggregators... Why do you need the data sorted, is there no alternative way? E.g. min/max_by(value, timestamp) patterns? – Moritz Jun 27 '22 at 15:36
  • If you're willing to shuffle all data around (without the benefit of a map side reduce) you can use the secondary sort pattern and then apply a reduce function to your partition iterator (sketched after these comments). – Moritz Jun 27 '22 at 15:40
  • @Moritz Thanks for your advice. The size of the dataframe to aggregate is not that big (since it comes in small batches), so shuffling could be tolerable. The required logic basically picks for each field the last value (order is defined by ts) as long as it is not null. That is why I was developing a udf – Ignacio Alorre Jun 27 '22 at 16:22
  • In that case you really don't have to bother about sorting. I'd recommend timestamping every field in your aggregation buffer and simply always keeping the latest non-null value. Alternatively you can use the built-in aggregation function `max_by` to do the same: `input.groupBy(...).agg(columns.map(name -> max_by(col(name), col("ts")))).as(encoder)` – Moritz Jun 28 '22 at 09:07
  • @Moritz timestamping every field separately is an option I considered, but the actual number of columns is rather big. However, would it be possible for you to write the second option as an answer? I believe it can work, and I will not even need the Aggregator if that max_by ignores the null values – Ignacio Alorre Jun 28 '22 at 10:50
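
For reference, here is a rough sketch of the secondary-sort idea from the comments above (an interpretation only, not the accepted approach): shuffle the rows by pk, sort every partition by (pk, ts), and then merge adjacent rows of the same key in a single pass over each partition iterator, reusing the Aggregator's zero and reduce:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.functions.col

val secondarySorted = seq_users
  .repartition(col("pk"))              // all rows of a given pk land in the same partition
  .sortWithinPartitions("pk", "ts")    // rows of the same pk are now adjacent and ordered by ts
  .mapPartitions { rows =>
    val out = ArrayBuffer.empty[(String, PlayerProcessed)]
    rows.foreach { p =>
      out.lastOption match {
        case Some((k, buf)) if k == p.pk =>
          out(out.length - 1) = k -> BatchDedupe.reduce(buf, p)  // same key: keep merging
        case _ =>
          out += p.pk -> BatchDedupe.reduce(BatchDedupe.zero, p) // new key: start a fresh buffer
      }
    }
    out.iterator
  }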

1 Answer


Based on the discussion above, the purpose of the Aggregator is to get the latest field value per Player (ordered by ts), ignoring null values.

This can be achieved fairly easily by aggregating all fields individually with max_by. With that there is no need for a custom Aggregator or a mutable aggregation buffer.

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._

val players: Dataset[Player] = ...

// aggregate all columns except the key individually, ordered by ts;
// the if() nulls out the ordering for NULL values, so max_by skips them
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => expr(s"max_by($colName, if(isNotNull($colName), ts, null))").as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]

On more recent versions of Spark you can also use the built-in max_by function from org.apache.spark.sql.functions instead of the SQL expression string:

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._

val players: Dataset[Player] = ...

// aggregate all columns except the key individually, ordered by ts;
// the when() nulls out the ordering for NULL values, so max_by skips them
val aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => max_by(col(colName), when(col(colName).isNotNull, col("ts"))).as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]
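
As a quick check (a hypothetical run over the four sample players from the question), the latest non-null value per column should be picked:

aggregatedPlayers.show(false)

which should give something like:

+--------------+--------+----------+-------------+
|pk            |ts      |first_name|date_of_birth|
+--------------+--------+----------+-------------+
|12121212121212|10000004|Roggelio  |1985-01-02   |
+--------------+--------+----------+-------------+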
Moritz
  • Thanks for the answer Moritz, but I am getting: Caused by: NotSerializableException: org.apache.spark.sql.Column Serialization stack: - object not serializable (class: org.apache.spark.sql.Column, value: max_by(ts, ts) AS ts) - element of array (index: 0) - array (class [Ljava.lang.Object;, size 3) - field (class: scala.collection.mutable.ArrayBuffer, name: array, type: class [Ljava.lang.Object;) - object (class scala.collection.mutable.ArrayBuffer, ArrayBuffer(max_by(ts, ts) AS ts, max_by(first_name, ts) AS first_name, max_by(date_of_birth, ts) AS date_of_birth)) - field (class: – Ignacio Alorre Jun 28 '22 at 15:52
  • This looks like your code contains a closure that depends on the current scope where `aggColumns` is defined and attempts to serialize everything. `Column` not being serializable is just a symptom of the actual problem here... In any case, you could just use `def aggColumns` instead of `val aggColumns` to prevent the issue (sketched below) – Moritz Jun 28 '22 at 16:18
  • thanks again, exception resolved and last value based on ts is selected. However null values are not ignored. For example, for the Dataset I used as example date_of_birth is "1985-01-02", but first_name is null – Ignacio Alorre Jun 28 '22 at 16:28
  • Strange, I don't think that's SQL conformant... anyway, you have to add an if/when clause then. I've updated the answer above – Moritz Jun 28 '22 at 17:16
  • Perfect, thanks a lot Moritz – Ignacio Alorre Jun 28 '22 at 17:25
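
For reference, a minimal sketch of the `def aggColumns` fix suggested in the comment above (the interpretation being that a `def` does not store the non-serializable Column objects as a field of the enclosing scope, so they are not dragged into serialized closures); the rest of the aggregation stays the same:

// declare the columns as a def instead of a val, as suggested in the comment
def aggColumns = players.columns
   .filterNot(_ == "pk")
   .map(colName => max_by(col(colName), when(col(colName).isNotNull, col("ts"))).as(colName))

val aggregatedPlayers = players
   .groupBy(col("pk"))
   .agg(aggColumns.head, aggColumns.tail: _*)
   .as[Player]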