
Why this is not a duplicate of "Spark SQL top n per group":

I myself thought this was a duplicate, just now. But it isn't after all. The difference is that I additionally need to perform an aggregation beforehand; I have edited my question below accordingly. I need totalScore to be the sum of all scores: the scores are used for sorting within a group and are then dismissed. Only the codes of the top-ranked elements per group shall enter the list, but totalScore shall be made up from all the scores. So we cannot dismiss some elements per group and aggregate later; we need to first aggregate over all elements and only afterwards get rid of some. This could be done by splitting the original DataFrame in two, doing each part separately, and then joining, but that does not sound very efficient.


I have a Spark DataFrame with the schema

root
 |-- inputRowID: long (nullable = false)
 |-- score: double (nullable = true)
 |-- code: string (nullable = true)

and I want to do

val outDF = inDF.
  sort($"inputRowID", $"score".desc).
  groupBy("inputRowID").
  agg(
    sum($"score").as("totalScore"),
    collect_list($"code").as("list"))

getting outDF with the schema

root
 |-- inputRowID: long (nullable = false)
 |-- totalScore: double (nullable = true)
 |-- list: array (nullable = true)
 |    |-- element: string (containsNull = true)

Now I want only the first n elements of each array to be kept. So I have been trying something like

outDF.
  map(r => Row(r(0), r(1).take(n)) )

(which of course does not work, since r(1) is statically typed as Any). Alternatively, I thought about taking the first n elements within each group, something like

val outDF = inDF.
  sort($"inputRowID", $"score".desc).
  groupBy("inputRowID").
  agg(take(n)).
  agg(
    collect_list($"code").as("list"))

but there is no such function as far as I can see. Any ideas?
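To make the desired result concrete, here is a sketch of the closest I can imagine working in a single aggregation (assuming Spark 2.4+ for the SQL functions slice, transform, array_sort, and reverse; the sample data and n = 2 are made up): collect (score, code) pairs into the array, sort inside the array, then trim.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder.master("local[*]").appName("topNPerGroup").getOrCreate()
import spark.implicits._

val n = 2  // hypothetical cut-off

// made-up sample data with the schema from the question
val inDF = Seq(
  (1L, 3.0, "a"), (1L, 1.0, "b"), (1L, 2.0, "c"),
  (2L, 5.0, "d"), (2L, 4.0, "e")
).toDF("inputRowID", "score", "code")

val outDF = inDF
  .groupBy("inputRowID")
  .agg(
    // aggregate over ALL rows first, so totalScore sees every score
    sum($"score").as("totalScore"),
    // keep the score next to its code so we can still sort afterwards
    collect_list(struct($"score", $"code")).as("pairs"))
  .select(
    $"inputRowID",
    $"totalScore",
    // array_sort orders the structs by score ascending, reverse flips that
    // to descending, transform keeps only the code, slice trims to the top n
    expr(s"slice(transform(reverse(array_sort(pairs)), p -> p.code), 1, $n)")
      .as("list"))
```

Sorting inside the aggregated array would also avoid relying on the sort-before-groupBy trick, since collect_list does not guarantee any ordering after a shuffle.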

Make42