Spark's Dataset provides a summary() API that computes basic statistics in the format below:
ds.summary("count", "min", "25%", "75%", "max").show()
// output:
// summary age height
// count 10.0 10.0
// min 18.0 163.0
// 25% 24.0 176.0
// 75% 32.0 180.0
// max 92.0 192.0
Similarly, you can enrich the DataFrame API to get the statistics in the format you require, as shown below.
Define a RichDataFrame class together with an implicits object to enable the syntax:
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{NumericType, StringType, StructField, StructType}
import scala.language.implicitConversions
/**
 * Enriches a [[DataFrame]] with a `statSummary` method that produces a
 * transposed statistics table: one row per column, one column per statistic.
 *
 * @param ds the DataFrame whose numeric and string columns are summarized
 */
class RichDataFrame(ds: DataFrame) {

  /**
   * Computes the requested statistics for every numeric/string column of `ds`.
   *
   * @param statistics aggregate names understood by Spark SQL (e.g. "mean",
   *                   "count", "min", "max", "std", "skewness", "kurtosis")
   *                   or percentiles written as "NN%" (e.g. "25%").
   *                   Defaults to max/min/mean/std/skewness/kurtosis.
   * @return a DataFrame with a leading "columns" column (the source column
   *         name) followed by one string-typed column per statistic
   * @throws IllegalArgumentException if a "NN%" token cannot be parsed or is
   *                                  outside [0%, 100%]
   */
  def statSummary(statistics: String*): DataFrame = {
    val defaultStatistics = Seq("max", "min", "mean", "std", "skewness", "kurtosis")
    val statFunctions = if (statistics.nonEmpty) statistics else defaultStatistics

    // Only numeric and string columns are summarized; other types (arrays,
    // structs, ...) are skipped so the aggregate expressions stay valid.
    val selectedCols = ds.schema
      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
      .map(_.name)

    // Parse "NN%" tokens into fractions for percentile_approx.
    val percentiles = statFunctions.filter(_.endsWith("%")).map { p =>
      try {
        p.stripSuffix("%").toDouble / 100.0
      } catch {
        case e: NumberFormatException =>
          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
      }
    }
    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")

    // Hoisted: the percentile array literal is the same for every column.
    val percentilesSql = percentiles.mkString(", ")

    // For each statistic, the index it occupies inside the percentile array
    // (meaningful only for "NN%" entries). scanLeft counts preceding
    // percentile tokens, replacing the original mutable counter.
    val percentileIdx = statFunctions.scanLeft(0) { (acc, s) =>
      if (s.endsWith("%")) acc + 1 else acc
    }
    val indexedStats = statFunctions.zip(percentileIdx)

    // Column names are backtick-quoted so names containing dots, spaces or
    // hyphens do not break the generated SQL.
    val aggExprs = selectedCols.flatMap { c =>
      indexedStats.map {
        case (stat, idx) if stat.endsWith("%") =>
          expr(s"cast(percentile_approx(`$c`, array($percentilesSql))[$idx] as string)")
        case (stat, _) =>
          expr(s"cast($stat(`$c`) as string)")
      }
    }

    // Single pass over the data: one row containing every (column, statistic)
    // cell, which is then re-grouped into one output row per source column.
    val aggResult = ds.select(aggExprs: _*).head()
    val rows = aggResult.toSeq.grouped(statFunctions.length).toArray
      .zip(selectedCols)
      .map { case (seq, column) => column +: seq }
      .map(Row.fromSeq)

    // All statistics are cast to string above, so the output schema is uniform.
    val output = StructField("columns", StringType) +: statFunctions.map(c => StructField(c, StringType))
    val spark = ds.sparkSession
    spark.createDataFrame(spark.sparkContext.parallelize(rows), StructType(output))
  }
}
object RichDataFrame {

  /** Mixin that makes any [[DataFrame]] implicitly convertible to [[RichDataFrame]]. */
  trait Enrichment {
    // Implicit view: lets callers write df.statSummary(...) directly.
    implicit def enrichMetadata(ds: DataFrame): RichDataFrame = new RichDataFrame(ds)
  }

  /** Import `RichDataFrame.implicits._` to bring the enrichment into scope. */
  object implicits extends Enrichment
}
Test it with the sample data below:
// Sample data: five rows across five integer columns.
val df = Seq(
  (10, 20, 30, 40, 50),
  (100, 200, 300, 400, 500),
  (111, 222, 333, 444, 555),
  (1123, 2123, 3123, 4123, 5123),
  (1321, 2321, 3321, 4321, 5321)
).toDF("col_1", "col_2", "col_3", "col_4", "col_5")

// Summarize only a subset of the columns.
val columnsToCalculate = Seq("col_2", "col_3", "col_4")

// Bring the statSummary enrichment into scope.
import com.som.spark.shared.RichDataFrame.implicits._

df.selectExpr(columnsToCalculate: _*)
  .statSummary("mean", "count", "25%", "75%", "90%")
  .show(truncate = false)
/**
* +-------+------+-----+---+----+----+
* |columns|mean |count|25%|75% |90% |
* +-------+------+-----+---+----+----+
* |col_2 |977.2 |5 |200|2123|2321|
* |col_3 |1421.4|5 |300|3123|3321|
* |col_4 |1865.6|5 |400|4123|4321|
* +-------+------+-----+---+----+----+
*/