I have a Spark Dataset, and I would like to group the data and process each group, yielding zero or one element per group. Something like:
val resulDataset = inputDataset
  .groupBy('x, 'y)
  .flatMap(...)
I didn't find a way to apply a function after a groupBy, but it appears I can use groupByKey instead (is it a good idea? is there a better way?):
val resulDataset = inputDataset
  .groupByKey(v => (v.x, v.y))
  .flatMap(...)
This works, but here is the thing: I would like to process the groups as Datasets. The reason is that I already have convenient functions that operate on Datasets, and I would like to reuse them when calculating the result for each group. However, groupByKey.flatMap yields an Iterator over the grouped elements, not a Dataset.
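To make the limitation concrete, here is a minimal sketch of the groupByKey route, assuming the per-group call is KeyValueGroupedDataset.flatMapGroups. The Event and Summary case classes and the "keep the group only if its values sum to a positive number" rule are hypothetical, chosen only to show a function that emits zero or one element per group:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical record types, used only for this illustration.
case class Event(x: String, y: Int, value: Double)
case class Summary(x: String, y: Int, maxValue: Double)

object GroupByKeySketch {
  // Emits zero or one Summary per (x, y) group.
  def summarize(input: Dataset[Event]): Dataset[Summary] = {
    import input.sparkSession.implicits._
    input
      .groupByKey(e => (e.x, e.y))
      .flatMapGroups { (key, rows) =>
        // `rows` is a plain Iterator[Event], not a Dataset[Event]:
        // only Scala collection operations are available here.
        val values = rows.map(_.value).toVector
        // Hypothetical rule: keep the group only if its values sum to > 0.
        if (values.sum > 0) Iterator.single(Summary(key._1, key._2, values.max))
        else Iterator.empty
      }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("groupByKey-sketch")
      .getOrCreate()
    import spark.implicits._

    val input = Seq(
      Event("a", 1, 2.0),
      Event("a", 1, 5.0),
      Event("b", 2, -1.0)
    ).toDS()

    summarize(input).show()
    spark.stop()
  }
}
```

The result is a distributed Dataset[Summary], but inside the function each group is only available as an Iterator, so existing Dataset-based helpers cannot be called on it.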
The question: is there a way in Spark to group an input Dataset and map a custom function over each group, while treating the grouped elements as a Dataset? E.g.:
val inputDataset: Dataset[T] = ...
val resulDataset: Dataset[U] = inputDataset
  .groupBy(...)
  .flatMap((group: Dataset[T]) => {
    // using Dataset API to calculate resulting value, e.g.:
    group.withColumn(row_number().over(...))....as[U]
  })
Note that the data in each group is bounded, so it is OK to process a single group on a single node. But the number of groups can be very high, so the resulting Dataset needs to be distributed. The point of using the Dataset API to process a group is purely a matter of convenience.
What I tried so far:
creating a Dataset from the Iterator inside the mapped function - it fails with an NPE from SparkSession (my understanding is that it boils down to the fact that one cannot create a Dataset within a function that processes a Dataset; see this and this)
trying to work around the first failure by creating a new SparkSession and building the Dataset within that new session; this fails with an NPE from SparkSession.newSession
(ab)using repartition('x, 'y).mapPartitions(...), but this also yields an Iterator[T] for each partition, not a Dataset[T]
finally, (ab)using filter: I can collect all distinct values of the grouping criteria into an Array (select.distinct.collect), and iterate this array to filter the source Dataset, yielding one Dataset for each group (sort of joins the idea of multiplexing from this article); although this works, my understanding is that it collects all the data on a single node, so it doesn't scale and will eventually have memory issues
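For reference, the filter-based attempt in the last item can be sketched roughly as follows. This is only an illustration under assumptions: the Event case class and the splitByKey helper are hypothetical names, and the grouping columns are assumed to be x and y as in the snippets above.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.col

// Hypothetical record type, used only for this illustration.
case class Event(x: String, y: Int, value: Double)

object FilterPerGroupSketch {
  // Returns one lazily evaluated Dataset per distinct (x, y) pair.
  def splitByKey(input: Dataset[Event]): Seq[((String, Int), Dataset[Event])] = {
    import input.sparkSession.implicits._
    // collect() brings the distinct grouping keys to the driver.
    val keys: Array[(String, Int)] = input
      .select(col("x"), col("y"))
      .distinct()
      .as[(String, Int)]
      .collect()
    // Each group is a separate filter over the full source Dataset.
    keys.toSeq.map { case (x, y) =>
      (x, y) -> input.filter(col("x") === x && col("y") === y)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("filter-sketch")
      .getOrCreate()
    import spark.implicits._

    val input = Seq(
      Event("a", 1, 2.0),
      Event("a", 1, 5.0),
      Event("b", 2, -1.0)
    ).toDS()

    // Every group is a real Dataset[Event], so Dataset-based helpers apply.
    splitByKey(input).foreach { case (key, group) =>
      println(s"group $key")
      group.show()
    }
    spark.stop()
  }
}
```

Each group here is a genuine Dataset, which is what makes the approach attractive, but the loop over groups runs on the driver and every group triggers its own pass over the source, so the cost grows with the number of groups.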