Why this is not a duplicate of "Spark sql top n per group":
I myself thought this was a duplicate just now, but it isn't after all. The difference is that I additionally need to perform an aggregation beforehand; I have edited my question below accordingly. I need `totalScore` to be the sum of all `score`s, while the `score`s themselves are only used for sorting within a group and are then dismissed. Only the `code` of the top-ranked elements per group shall enter the list, but `totalScore` shall be made up from all the `score`s. So we cannot dismiss some elements per group and aggregate later; we have to aggregate first, keeping all elements, and only afterwards get rid of some. This could be done by splitting the original DataFrame in two, doing each thing separately, and then joining, as in the sketch below, but that does not sound very efficient.
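A minimal sketch of that split-and-join idea (column names are taken from the schema below; `n` and the window-based ranking are my own assumptions):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, collect_list, row_number, sum}

// branch 1: aggregate over ALL rows, so totalScore sees every score
val totals = inDF.
  groupBy("inputRowID").
  agg(sum(col("score")).as("totalScore"))

// branch 2: keep only the top n rows per group, then collect their codes
val w = Window.partitionBy("inputRowID").orderBy(col("score").desc)
val topCodes = inDF.
  withColumn("rank", row_number().over(w)).
  where(col("rank") <= n).
  groupBy("inputRowID").
  agg(collect_list(col("code")).as("list"))

// join the two branches back together: two aggregations plus a join
val outDF = totals.join(topCodes, "inputRowID")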
I have a Spark DataFrame `inDF` with the schema
root
|-- inputRowID: long (nullable = false)
|-- score: double (nullable = true)
|-- code: string (nullable = true)
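For concreteness, a hypothetical toy `inDF` with this schema could be built like this (the values are made up, and a `SparkSession` named `spark` is assumed to be in scope):

import spark.implicits._

val inDF = Seq(
  (1L, 0.9, "a"), (1L, 0.5, "b"), (1L, 0.1, "c"),
  (2L, 0.7, "x"), (2L, 0.3, "y")
).toDF("inputRowID", "score", "code")
// note: nullability may differ slightly from the schema above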
Given that, I want to do
import org.apache.spark.sql.functions.{collect_list, sum}

val outDF = inDF.
  sort($"inputRowID", $"score".desc).
  groupBy("inputRowID").
  agg(
    sum($"score").as("totalScore"),
    collect_list($"code").as("list"))
getting `outDF` with the schema
root
|-- inputRowID: long (nullable = false)
|-- totalScore: double (nullable = true)
|-- list: array (nullable = true)
| |-- element: string (containsNull = true)
Now I want only the first `n` elements of each array to be kept. So I have been trying something like
outDF.map(r => Row(r(0), r(1).take(n)))
(which of course does not work, since `r(1)` is just an `Any` there; a version that compiles is sketched at the end of this question). Alternatively, I thought about taking the first `n` elements within each group before aggregating, something like
val outDF = inDF.
  sort($"inputRowID", $"score".desc).
  groupBy("inputRowID").
  agg(take(n)).
  agg(collect_list($"code").as("list"))
but there is no such function as far as I can see. Any ideas?
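For reference, here is a typed version of the `map` idea above that does compile: it simply trims each collected list after the aggregation. The UDF route should work on any Spark version; `slice` is my suggestion for Spark 2.4+, and `n` is assumed to be defined:

import org.apache.spark.sql.functions.{col, udf}

// trim each collected list down to its first n entries
val takeN = udf((xs: Seq[String]) => xs.take(n))  // n captured from the enclosing scope
val trimmed = outDF.withColumn("list", takeN(col("list")))

// Spark 2.4+ alternative, without a UDF:
// import org.apache.spark.sql.functions.slice
// val trimmed = outDF.withColumn("list", slice(col("list"), 1, n))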