First, you need to know that `histogram` generates two separate, sequential jobs: one to find the minimum and maximum of your data, and one to compute the actual histogram. You can check this in the Spark UI.
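To make the two passes concrete, here is a minimal single-machine sketch in plain Scala, with no Spark involved. In Spark, each pass would be a separate job over the distributed data. The object `TwoPassHistogram` and its `histogram` method are hypothetical names. The bucket semantics (evenly spaced buckets, open on the right except the last one, which also includes the maximum) are an assumption modeled on the documented behavior of Spark's `histogram`, not its actual implementation.

```scala
object TwoPassHistogram {
  // Hypothetical local sketch: in Spark, each of the two passes below is a separate job.
  def histogram(data: Seq[Double], bucketCount: Int): (Vector[Double], Vector[Long]) = {
    // Pass 1: find the bounds of the data (throws on an empty Seq).
    val lo = data.min
    val hi = data.max
    val step = (hi - lo) / bucketCount
    // Evenly spaced boundaries; the last one is set to hi exactly.
    val boundaries = Vector.tabulate(bucketCount)(i => lo + i * step) :+ hi
    // Pass 2: count each value. Buckets are [b(i), b(i + 1)),
    // except the last one, which also includes hi.
    val counts = Array.fill(bucketCount)(0L)
    for (x <- data) {
      val i = boundaries.lastIndexWhere(x >= _)
      counts(math.min(i, bucketCount - 1)) += 1
    }
    (boundaries, counts.toVector)
  }
}
```

Unlike this sketch, the `range_index` UDF used in this answer does not fold the maximum into the last bucket: values equal to the maximum get an extra index of their own.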
We can follow the same scheme to build histograms for as many columns as we wish, with only two jobs. However, we cannot use the `histogram` function, which only handles a single collection of doubles, so we need to implement it ourselves. Assuming `org.apache.spark.sql.functions._` and `spark.implicits._` (for the `'trx` column syntax) are imported, the first job is dead simple:
```scala
import org.apache.spark.sql.Row
val Row(min_trx: Double, max_trx: Double) = df.select(min('trx), max('trx)).head
```
Then we compute the histogram ranges locally, on the driver. Note that I use the same ranges for all the columns, which makes it easy to compare the results between columns (e.g. by plotting them on the same figure). Using different ranges per column would only take a small modification of this code, though.
```scala
val hist_size = 10
val hist_step = (max_trx - min_trx) / hist_size
// hist_size + 1 boundaries: min_trx, min_trx + hist_step, ..., max_trx
val hist_ranges = (1 until hist_size)
  .scanLeft(min_trx)((a, _) => a + hist_step) :+ max_trx
// I add max_trx manually to avoid rounding errors that would exclude the value
```
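A quick way to check what `hist_ranges` contains is to evaluate the same expression on round numbers. The sketch below wraps it in a hypothetical `HistRanges.ranges` function (not part of Spark); with a min of 0, a max of 10 and 5 buckets, it yields the 6 boundaries `0, 2, 4, 6, 8, 10`.

```scala
object HistRanges {
  // Same construction as hist_ranges: hist_size + 1 boundaries from min_trx to max_trx.
  def ranges(min_trx: Double, max_trx: Double, hist_size: Int): IndexedSeq[Double] = {
    val hist_step = (max_trx - min_trx) / hist_size
    (1 until hist_size).scanLeft(min_trx)((a, _) => a + hist_step) :+ max_trx
  }
}
```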
That was the first part. Then we can use a UDF to determine which range each value falls into, and compute all the histograms in parallel with Spark:
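```scala
import org.apache.spark.sql.functions.udf

// index of the last boundary <= x: i means hist_ranges(i) <= x < hist_ranges(i + 1),
// and values equal to max_trx get index hist_size
val range_index = udf((x: Double) => hist_ranges.lastIndexWhere(x >= _))
val hist_df = df
  .withColumn("rangeIndex", range_index('trx))
  .groupBy("M1", "rangeIndex")
  .count()
// And voilà, all the data you need is there.
hist_df.show()
```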
```
+---+----------+-----+
| M1|rangeIndex|count|
+---+----------+-----+
| M2|         2|    2|
| M1|         0|    2|
| M2|         5|    1|
| M1|         3|    2|
| M2|         3|    1|
| M1|         7|    1|
| M2|        10|    1|
+---+----------+-----+
```
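To see what the UDF plus `groupBy(...).count()` computes without a cluster, here is a plain-Scala sketch of the same logic on an in-memory `Seq` of `(M1, trx)` pairs. `LocalHistCounts` and its methods are hypothetical names, and Scala's collection `groupBy` stands in for the DataFrame aggregation.

```scala
object LocalHistCounts {
  // Plain-Scala stand-in for the UDF: index of the last boundary that is <= x.
  // A value equal to the last boundary (the max) gets index boundaries.size - 1.
  def rangeIndex(boundaries: IndexedSeq[Double])(x: Double): Int =
    boundaries.lastIndexWhere(x >= _)

  // Stand-in for groupBy("M1", "rangeIndex").count() on (M1, trx) pairs.
  def counts(rows: Seq[(String, Double)], boundaries: IndexedSeq[Double]): Map[(String, Int), Long] =
    rows
      .groupBy { case (m1, trx) => (m1, rangeIndex(boundaries)(trx)) }
      .map { case (key, group) => key -> group.size.toLong }
}
```

For instance, with boundaries `0, 2, 4, 6, 8, 10` and the rows `(M1, 0.5)`, `(M1, 1.5)`, `(M2, 4.0)`, `(M2, 10.0)`, the counts are `(M1, 0) -> 2`, `(M2, 2) -> 1` and `(M2, 5) -> 1`: the maximum lands at index 5, i.e. `hist_size`, just as with `range_index`.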
As a bonus, you can reshape the data to use it locally (within the driver), either with the RDD API or by collecting the dataframe and transforming it in Scala.
Here is one way to do it with Spark, since this is a question about Spark ;-)
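```scala
// maps each M1 value to an array of counts indexed by rangeIndex
val hist_map = hist_df.rdd
  .map(row => row.getAs[String]("M1") ->
    (row.getAs[Int]("rangeIndex"), row.getAs[Long]("count")))
  .groupByKey
  .mapValues(_.toMap)
  // range_index produces indices 0 to hist_size (hist_size holds values equal to max_trx)
  .mapValues(hists => (0 to hist_size)
    .map(i => hists.getOrElse(i, 0L)).toArray)
  .collectAsMap
```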
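The same reshaping can be sketched in plain Scala on already-collected rows, which also makes the index range easy to check: `range_index` produces indices from 0 to `hist_size`, so the dense arrays below have `hist_size + 1` slots. `LocalHistMap.toHistMap` is a hypothetical helper, not a Spark API.

```scala
object LocalHistMap {
  // Turn (M1, rangeIndex, count) rows into one dense count array per M1 value,
  // with one slot per index from 0 to hist_size (missing indices become 0).
  def toHistMap(rows: Seq[(String, Int, Long)], hist_size: Int): Map[String, Array[Long]] =
    rows
      .groupBy(_._1)
      .map { case (m1, group) =>
        val hists = group.map { case (_, idx, count) => idx -> count }.toMap
        m1 -> (0 to hist_size).map(i => hists.getOrElse(i, 0L)).toArray
      }
}
```

Fed the rows from the `hist_df.show()` output above with `hist_size = 10`, it maps M1 to `2, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0` and M2 to `0, 0, 2, 1, 0, 1, 0, 0, 0, 0, 1`.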
EDIT: how to build one range per column value:

Instead of computing the global min and max of trx, we compute them for each value of M1 with `groupBy`:
```scala
val min_max_map = df.groupBy("M1")
  .agg(min('trx), max('trx))
  .rdd.map(row => row.getAs[String]("M1") ->
    (row.getAs[Double]("min(trx)"), row.getAs[Double]("max(trx)")))
  .collectAsMap // maps each column value to a tuple (min, max)
```
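In plain Scala, the per-value min and max amount to a `groupBy` on the key followed by a `min` and a `max` over each group. This hypothetical `LocalMinMax` sketch mirrors what `min_max_map` holds, assuming rows of `(M1, trx)` pairs.

```scala
object LocalMinMax {
  // Plain-Scala stand-in for groupBy("M1").agg(min('trx), max('trx)) followed by collectAsMap.
  def minMaxByKey(rows: Seq[(String, Double)]): Map[String, (Double, Double)] =
    rows.groupBy(_._1).map { case (m1, group) =>
      val trx = group.map(_._2)
      m1 -> (trx.min, trx.max)
    }
}
```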
Then we adapt the UDF so that it uses this map, and we are done:
```scala
// for clarity, let's define a function that generates histogram ranges
def generate_ranges(min_trx: Double, max_trx: Double, hist_size: Int) = {
  val hist_step = (max_trx - min_trx) / hist_size
  (1 until hist_size).scanLeft(min_trx)((a, _) => a + hist_step) :+ max_trx
}
// and use it to generate one range per column value
val range_map = min_max_map.keys
  .map(key => key ->
    generate_ranges(min_max_map(key)._1, min_max_map(key)._2, hist_size))
  .toMap
// replaces the single-argument range_index: each value is bucketed
// against the ranges of its own M1 value
val range_index = udf((x: Double, m1: String) =>
  range_map(m1).lastIndexWhere(x >= _))
```
Finally, just replace `range_index('trx)` with `range_index('trx, 'M1)` and you will have one range per column value.
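Putting the edit together, here is a plain-Scala sketch of per-value bucketing: `generateRanges` repeats the `generate_ranges` construction, and `rangeIndex` plays the role of the two-argument UDF. Both live in a hypothetical `PerGroupRanges` object and run without Spark.

```scala
object PerGroupRanges {
  // Same construction as generate_ranges: hist_size + 1 boundaries from min_trx to max_trx.
  def generateRanges(min_trx: Double, max_trx: Double, hist_size: Int): IndexedSeq[Double] = {
    val hist_step = (max_trx - min_trx) / hist_size
    (1 until hist_size).scanLeft(min_trx)((a, _) => a + hist_step) :+ max_trx
  }

  // Stand-in for the two-argument range_index UDF:
  // each value is bucketed against the ranges of its own M1 value.
  def rangeIndex(rangeMap: Map[String, IndexedSeq[Double]])(x: Double, m1: String): Int =
    rangeMap(m1).lastIndexWhere(x >= _)
}
```

For example, with M1 ranges from 0 to 10 and M2 ranges from 100 to 200 (5 buckets each), `5.0` for M1 and `150.0` for M2 both get index 2. Two consequences of this design are worth knowing. If every row of a given M1 value has the same trx, its step is 0, all of its boundaries are equal, and every value lands at index `hist_size`. Also, `rangeMap(m1)` throws for an M1 value missing from the map, which should not happen here since both come from the same `df`.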