
In Spark SQL I have a DataFrame with a column `col` that contains an array of `Int` of size 100 (for instance).

I want to aggregate this column into a single value: an array of `Int` of size 100 containing the element-wise sum over all rows. It is possible to do this by calling:

dataframe.agg(functions.array((0 until 100).map(i => functions.sum(functions.col("col").getItem(i))) : _*))

This generates code that explicitly performs 100 separate aggregations and then presents the 100 results as an array of 100 items. However, this seems very inefficient, and Catalyst even fails to generate the code when my array size exceeds ~1000 items. Is there a construct in Spark SQL to do this more efficiently? Ideally it should be possible to automatically propagate the sum aggregation over an array to do a member-wise sum, but I didn't find anything related to this in the docs. What are the alternatives to my code?

Edit: my traceback:

   ERROR codegen.CodeGenerator: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Code of method "processNext()V" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator" grows beyond 64 KB
    at org.codehaus.janino.UnitCompiler.compileUnit(UnitCompiler.java:361)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:234)
    at org.codehaus.janino.SimpleCompiler.compileToClassLoader(SimpleCompiler.java:446)
    at org.codehaus.janino.ClassBodyEvaluator.compileToClass(ClassBodyEvaluator.java:313)
    at org.codehaus.janino.ClassBodyEvaluator.cook(ClassBodyEvaluator.java:235)
    at org.codehaus.janino.SimpleCompiler.cook(SimpleCompiler.java:204)
    at org.codehaus.commons.compiler.Cookable.cook(Cookable.java:80)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.org$apache$spark$sql$catalyst$expressions$codegen$CodeGenerator$$doCompile(CodeGenerator.scala:1002)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1069)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$$anon$1.load(CodeGenerator.scala:1066)
    at org.spark_project.guava.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599)
    at org.spark_project.guava.cache.LocalCache$Segment.loadSync(LocalCache.java:2379)
    at org.spark_project.guava.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342)
    at org.spark_project.guava.cache.LocalCache$Segment.get(LocalCache.java:2257)
    at org.spark_project.guava.cache.LocalCache.get(LocalCache.java:4000)
    at org.spark_project.guava.cache.LocalCache.getOrLoad(LocalCache.java:4004)
    at org.spark_project.guava.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:948)
    at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:375)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:97)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$doExecute$1.apply(HashAggregateExec.scala:92)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doExecute(HashAggregateExec.scala:92)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.prepareShuffleDependency(ShuffleExchange.scala:88)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:124)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange$$anonfun$doExecute$1.apply(ShuffleExchange.scala:115)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    at org.apache.spark.sql.execution.exchange.ShuffleExchange.doExecute(ShuffleExchange.scala:115)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply$mcV$sp(FileFormatWriter.scala:173)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
    at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:166)
    at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:145)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.datasources.DataSource.writeInFileFormat(DataSource.scala:435)
    at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:471)
    at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:58)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:56)
    at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:117)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:138)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:135)
    at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:116)
    at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:92)
    at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:92)
    at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:609)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:233)
    at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
    at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:597)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.runAndSaveMetrics(RankingMetricsComputer.scala:286)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer$.main(RankingMetricsComputer.scala:366)
    at com.criteo.enterprise.eligibility_metrics.RankingMetricsComputer.main(RankingMetricsComputer.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$2.run(ApplicationMaster.scala:635)
lezebulon
  • Sorry I don't have the catalyst exception right now, but it was related to the generated code being too large / using too many variables – lezebulon Oct 16 '18 at 20:33
  • When you have a chance please [edit] your question and attach the traceback. It will help to diagnose the issue, and find possible solution. Also, it would be great if you could include type annotations (what is `dataframe` - `Dataset[_]`, `RelationalGroupedDataset`?). Performance-wise you won't find a better solution than an aggregation anyway. – zero323 Oct 16 '18 at 21:42
  • Great. And what is the type of `dataframe`? – zero323 Oct 18 '18 at 16:49
  • Related [SparkSQL job fails when calling stddev over 1,000 columns](https://stackoverflow.com/q/50425948/6910411). – zero323 Oct 18 '18 at 16:55

1 Answer


The best way to do this is to convert the nested arrays into their own rows so you can use a single groupBy. That way you can do it all in one aggregation instead of 100 (or more). The key is to use posexplode, which turns each entry in the array into a new row together with the index it occupied in the array.

For example:

import org.apache.spark.sql.functions.{posexplode, collect_list, struct, sum}

// assumes a SparkSession named `spark` is in scope (e.g. in spark-shell), for toDF and the $ syntax
import spark.implicits._

val data = Seq(
    (Seq(1, 2, 3, 4, 5)),
    (Seq(2, 3, 4, 5, 6)),
    (Seq(3, 4, 5, 6, 7))
)

val df = data.toDF

val df2 = df.
    select(posexplode($"value")).
    groupBy($"pos").
    agg(sum($"col") as "sum")

// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show

This would output a DataFrame like this:

+---+---+
|pos|sum|
+---+---+
|  0|  6|
|  1|  9|
|  2| 12|
|  3| 15|
|  4| 18|
+---+---+

Or, if you want them in one row, you could add something like this:

df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show

The values in the resulting array column wouldn't be sorted, but you could write a UDF to sort them by the pos field and then drop pos if you wanted.
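
One hedged alternative to a sorting UDF (not from the original answer): sort_array on an array of structs orders elements by the first struct field, which is pos here, so you can sort the collected list and then project out just the sums. The names df2, pos, and sum refer to the example above.

import org.apache.spark.sql.functions.{collect_list, sort_array, struct}

// Sort the collected (pos, sum) structs by pos (the first struct field),
// then extract only the sum field from each element.
val sorted = df2.
    groupBy().
    agg(sort_array(collect_list(struct($"pos", $"sum"))) as "list").
    select($"list.sum" as "sums")

sorted.show(false)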

Updated per comment

If the above approach doesn't work with whatever other aggregations you are trying to do, then you would need to define your own UDAF. The general idea is that you tell Spark how to combine values for the same key inside a partition to create intermediate values, and then how to combine those intermediate values across partitions to create the final value for each key. Once you have defined a UDAF class, you can use it in the agg call alongside any other aggregations you would like to do.

Here is a quick example I knocked out. Note that it hard-codes the array length (100) and should probably be made more robust, but it should get you most of the way there.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction


class ArrayCombine extends UserDefinedAggregateFunction {
  // The input this aggregation will receive (each row)
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // Your intermediate state as you are updating with data from each row
  override def bufferSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer: an array of 100 zeros.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array.fill(100)(0)
  }

  // Given a new input row, update our state
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sums = buffer.getSeq[Int](0)
    val newVals = input.getSeq[Int](0)

    buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
  }

  // After we have finished computing intermediate values for each partition, combine the partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sums1 = buffer1.getSeq[Int](0)
    val sums2 = buffer2.getSeq[Int](0)

    buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[Int](0)
  }
}

Then call it like this:

val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show
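
And, as a hedged sketch of mixing the UDAF with other aggregations in the same agg call (the grouping column someKey is hypothetical; the example df above only has the value column):

import org.apache.spark.sql.functions.count

// Hypothetical grouping key someKey; combine the UDAF with a built-in aggregation.
df.groupBy($"someKey").
    agg(arrayUdaf($"value") as "summedArray", count($"value") as "rows").
    show(false)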
Ryan Widmaier
  • `UserDefinedAggregateFunction` like this one is really bad advice if @lezebulon is looking for better performance - [Spark UDAF with ArrayType as bufferSchema performance issues](https://stackoverflow.com/q/47293454/6910411), especially with a large array. – zero323 Oct 16 '18 at 21:33
  • The linked methods are faster, but if the op needs to run multiple simultaneous aggs, I would try the UDAF first and see how it does on his data before attempting more specialized and less flexible approaches. – Ryan Widmaier Oct 16 '18 at 21:53
  • Thanks a lot, I can't test today anymore but I'll report my result tomorrow – lezebulon Oct 16 '18 at 22:01
  • @user6910411 can't I reuse the answer to the question you linked and `use primitive types in place of ArrayType` for my UDAF? If so, shouldn't it be faster while still keeping a UDAF? (see the sketch after these comments) – lezebulon Oct 16 '18 at 22:09
  • @lezebulon You might try, though if you have some Catalyst problems, it might not change a thing, and even if it does, it won't be more than a workaround. Out of curiosity - does your code work when you skip `array`? – zero323 Oct 16 '18 at 23:58
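
Regarding the last two comments, here is a hedged sketch (not from the original answer, and assuming a fixed array length of 100) of a UDAF whose buffer uses one IntegerType field per array slot instead of a single ArrayType field, along the lines of the linked question about ArrayType buffer performance:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

// Hypothetical variant: one IntegerType buffer field per array slot, so updates
// touch primitive fields instead of rewriting a whole ArrayType buffer per row.
// The fixed size (100) is an assumption.
class FlatArraySum(size: Int = 100) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    StructType(StructField("value", ArrayType(IntegerType)) :: Nil)

  override def bufferSchema: StructType =
    StructType((0 until size).map(i => StructField(s"sum_$i", IntegerType)))

  override def dataType: DataType = ArrayType(IntegerType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    (0 until size).foreach(i => buffer(i) = 0)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val arr = input.getSeq[Int](0)
    var i = 0
    while (i < size) { buffer(i) = buffer.getInt(i) + arr(i); i += 1 }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var i = 0
    while (i < size) { buffer1(i) = buffer1.getInt(i) + buffer2.getInt(i); i += 1 }
  }

  override def evaluate(buffer: Row): Any = (0 until size).map(buffer.getInt)
}

// Usage would be the same as the UDAF above (hypothetical):
val flatUdaf = new FlatArraySum(100)
df.groupBy().agg(flatUdaf($"value") as "sums").show(false)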