
I am trying to use the agg function with a type-safe check. I created a case class for the dataset and defined its schema:

import org.apache.spark.sql.types._

case class JosmSalesRecord(orderId: Long,
                           totalOrderSales: Double,
                           totalOrderCount: Long)

object JosmSalesRecord extends SparkSessionWrapper {
  import sparkSession.implicits._

  // orderId and totalOrderCount are Long in the case class,
  // so the corresponding StructFields use LongType (not IntegerType).
  val schema: Option[StructType] = Some(StructType(Seq(
    StructField("order_id", LongType, true),
    StructField("total_order_sales", DoubleType, true),
    StructField("total_order_count", LongType, true)
  )))
}

Dataset:

+----------+------------------+---------------+
|   orderId|   totalOrderSales|totalOrderCount|
+----------+------------------+---------------+
|1071131089|            433.68|              8|
|1071386263|  8848.42000000001|            343|
|1071439146|108.39999999999999|              8|
|1071349454|34950.400000000074|            512|
|1071283654|349.65000000000003|             27|
+----------+------------------+---------------+

root
 |-- orderId: long (nullable = false)
 |-- totalOrderSales: double (nullable = false)
 |-- totalOrderCount: long (nullable = false)

I am applying the following function to the dataset:

import org.apache.spark.sql.expressions.scalalang.typed

val pv = aggregateSum.agg(
  typed.sum[JosmSalesRecord](_.totalOrderSales),
  typed.sumLong[JosmSalesRecord](_.totalOrderCount)
).toDF("totalOrderSales", "totalOrderCount")

When I apply pv.show(), I get the following exception:

java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to com.pipeline.dataset.JosmSalesRecord
at com.FirstApp$$anonfun$7.apply(FirstApp.scala:78)
at org.apache.spark.sql.execution.aggregate.TypedSumDouble.reduce(typedaggregators.scala:32)
at org.apache.spark.sql.execution.aggregate.TypedSumDouble.reduce(typedaggregators.scala:30)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doConsume_1$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.serializefromobject_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.mapelements_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.deserializetoobject_doConsume_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doAggregateWithKeysOutput_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.agg_doAggregateWithoutKey_0$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage9.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)

Note: I strictly want to use the type-safe sum from import org.apache.spark.sql.expressions.scalalang.typed. I get the correct result when I apply sum from org.apache.spark.sql.functions instead.
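
For reference, the untyped version that works for me looks roughly like this (column names as in the dataset above):

import org.apache.spark.sql.functions.sum

// Untyped equivalent: works, but without compile-time checks on the fields.
val pvUntyped = aggregateSum.agg(
  sum("totalOrderSales").as("totalOrderSales"),
  sum("totalOrderCount").as("totalOrderCount"))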

Amit

1 Answer


If aggregateSum is a DataFrame, e.g.

val aggregateSum = Seq(
  (1071131089L, 433.68, 8L),
  (1071386263L, 8848.42000000001, 343L)
).toDF("orderId", "totalOrderSales", "totalOrderCount")
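
then the ClassCastException is expected: typed.sum[JosmSalesRecord] is bound to the Dataset's encoder, and a DataFrame's encoder deserializes rows as GenericRowWithSchema, not as JosmSalesRecord. On Spark 2.x you can keep the typed aggregators by converting to a typed Dataset first; a minimal sketch, assuming the JosmSalesRecord encoder from sparkSession.implicits._ is in scope:

import org.apache.spark.sql.expressions.scalalang.typed

// .as[JosmSalesRecord] makes the typed aggregators receive case-class
// instances instead of Rows, so the cast no longer fails.
val pvTyped = aggregateSum.as[JosmSalesRecord]
  .agg(typed.sum[JosmSalesRecord](_.totalOrderSales),
       typed.sumLong[JosmSalesRecord](_.totalOrderCount))
  .toDF("totalOrderSales", "totalOrderCount")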

In current versions of Spark (i.e. 3.x), typed is deprecated. You can still keep type safety with something like the following (Spark 3), which puts every record into a single group and then folds that group down to one OutputRecord:

case class OutputRecord(totalOrderSales: Double, totalOrderCount: Long)

val pv = aggregateSum.as[JosmSalesRecord]
  .groupByKey(_ => 1)  // one group => one global aggregate
  .mapGroups { (_, vs) =>
    vs.map((x: JosmSalesRecord) => (x.totalOrderSales, x.totalOrderCount))
      .reduceOption((x, y) => (x._1 + y._1, x._2 + y._2))
      .map { case (sales, count) => OutputRecord(sales, count) }
      .getOrElse(OutputRecord(0.0, 0L))
  }
pv.show()

gives

+----------------+---------------+
| totalOrderSales|totalOrderCount|
+----------------+---------------+
|9282.10000000001|            351|
+----------------+---------------+
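
Note that the reduceOption/getOrElse pair means an empty input Dataset yields OutputRecord(0.0, 0L) rather than the exception a bare reduce would throw.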

If you already have Scala cats as a dependency, you can also do something like this:
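
A minimal sketch, assuming cats-core is on the classpath: |+| (Semigroup combine) merges the (Double, Long) pairs element-wise, replacing the hand-written tuple sum above.

import cats.implicits._  // tuple Semigroup instances and the |+| syntax

val pvCats = aggregateSum.as[JosmSalesRecord]
  .groupByKey(_ => 1)
  .mapGroups { (_, vs) =>
    vs.map(x => (x.totalOrderSales, x.totalOrderCount))
      .reduceOption(_ |+| _)  // element-wise sum of the pairs
      .map { case (sales, count) => OutputRecord(sales, count) }
      .getOrElse(OutputRecord(0.0, 0L))
  }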

ELinda