
I need to select the last 'name' for each given 'id'. A possible solution could be the following:

val channels = sessions
    .select($"start_time", $"id", $"name")
    .orderBy($"start_time")
    .select($"id", $"name")
    .groupBy($"id")
    .agg(last("name"))

I don't know whether it's correct, because I'm not sure the ordering from orderBy is preserved after groupBy.

But it's certainly not a performant solution. I should probably use reduceByKey. I tried the following in the spark shell and it works:

val x = sc.parallelize(Array(("1", "T1"), ("2", "T2"), ("1", "T11"), ("1", "T111"), ("2", "T22"), ("1", "T100"), ("2", "T222"), ("2", "T200")), 3)
x.reduceByKey((acc,x) => x).collect

But it doesn't work with my dataframe.

case class ChannelRecord(id: Long, name: String)
val channels = sessions
    .select($"start_time", $"id", $"name")
    .orderBy($"start_time")
    .select($"id", $"name")
    .as[ChannelRecord]
    .reduceByKey((acc, x) => x) // take the last object

I got a compilation error: value reduceByKey is not a member of org.apache.spark.sql.Dataset

I think I should add a map() call before doing reduceByKey, but I am not sure what exactly to map.
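From what I can tell, reduceByKey is only defined on pair RDDs (RDD[(K, V)]), which would explain the compile error. A sketch of what I'm aiming at (assuming id is a Long and start_time can be cast to a Long epoch) would be something like the following; note that a plain (acc, x) => x reducer depends on encounter order, which Spark does not guarantee after a shuffle, so the timestamp has to be carried along:

val channels = sessions
    .select($"id", $"name", $"start_time".cast("long"))
    .rdd                                                      // drop to the RDD API, where reduceByKey lives
    .map(r => (r.getLong(0), (r.getString(1), r.getLong(2)))) // (id, (name, start_time))
    .reduceByKey((a, b) => if (a._2 >= b._2) a else b)        // keep the name with the max start_time
    .mapValues(_._1)                                          // keep just the name per id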


1 Answer

You could do it with a window function, for example. This will require a shuffle on the id column and a sort on start_time.

There are two stages:

  • Get the last name for each id
  • Keep only the rows with that last name (max start_time)

Example dataframe:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val rowsRdd: RDD[Row] = spark.sparkContext.parallelize(
  Seq(
    Row(1, "a", 1),
    Row(1, "b", 2),
    Row(1, "c", 3),
    Row(2, "d", 4),
    Row(2, "e", 5),
    Row(2, "f", 6),
    Row(3, "g", 7),
    Row(3, "h", 8)
  ))


val schema: StructType = new StructType()
  .add(StructField("id",         IntegerType, false))
  .add(StructField("name",       StringType,  false))
  .add(StructField("start_time", IntegerType, false))


val df0: DataFrame = spark.createDataFrame(rowsRdd, schema)

Define a window. Note that I am sorting here by start_time in decreasing order, so that the first row in the next step is the one with the max start_time.

val w = Window.partitionBy("id").orderBy(col("start_time").desc)

Then

df0.withColumn("last_name", first("name").over(w)) // first over the descending window = name with the max start_time
    .withColumn("row_number", row_number().over(w)) // number the rows within each id, newest first
    .filter("row_number=1") // keep only the first row per id (first row = max start_time)
    .drop("row_number") // get rid of the helper row_number column
    .sort("id")
    .show(10, false)

This returns

+---+----+----------+---------+
|id |name|start_time|last_name|
+---+----+----------+---------+
|1  |c   |3         |c        |
|2  |f   |6         |f        |
|3  |h   |8         |h        |
+---+----+----------+---------+
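
Applied to the sessions dataframe from the question (a sketch, assuming the same id, name and start_time columns as described there), the same pattern reduces to keeping the top row per id:

val byId = Window.partitionBy("id").orderBy(col("start_time").desc)

val channels = sessions
    .withColumn("rn", row_number().over(byId)) // rank the rows within each id, newest first
    .filter("rn = 1")                          // keep only the row with the max start_time
    .select("id", "name")                      // just the columns the question asked for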