
The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.

When I run this same general type of code on a large volume of production data, I am having runtime problems. The Spark job runs on my cluster for ~12 hours and fails out.

Just glancing at the code below, it seems inefficient to explode every row only to merge it back down. In the given test data set, the fourth row, with three values in array_value_1 and three values in array_value_2, will explode to 3*3 = 9 rows.

So, in a larger data set, would a row with five such array columns and ten values in each column explode out to 10^5 = 100,000 rows?

Looking at the provided Spark functions, there is no out-of-the-box function that does what I want. I could supply a user-defined function (UDF). Are there any speed drawbacks to that?

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_set, explode}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
import scala.collection.JavaConverters._

val sparkSession = SparkSession.builder
  .master("local")
  .appName("merge list test")
  .getOrCreate()

val schema = StructType(
  StructField("category", IntegerType) ::
    StructField("array_value_1", ArrayType(StringType)) ::
    StructField("array_value_2", ArrayType(StringType)) ::
    Nil)

val rows = List(
  Row(1, List("a", "b"), List("u", "v")),
  Row(1, List("b", "c"), List("v", "w")),
  Row(2, List("c", "d"), List("w")),
  Row(2, List("c", "d", "e"), List("x", "y", "z"))
)

val df = sparkSession.createDataFrame(rows.asJava, schema)

val dfExploded = df.
  withColumn("scalar_1", explode(col("array_value_1"))).
  withColumn("scalar_2", explode(col("array_value_2")))

// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")

val dfOutput = dfExploded.groupBy("category").agg(
  collect_set("scalar_1").alias("combined_values_1"),
  collect_set("scalar_2").alias("combined_values_2"))

dfOutput.show()
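
For reference, with the demo data above the combined sets are {a, b, c} / {u, v, w} for category 1 and {c, d, e} / {w, x, y, z} for category 2, so show() should print something along these lines (collect_set gives no ordering guarantee, so the element order inside the arrays may differ):

+--------+-----------------+-----------------+
|category|combined_values_1|combined_values_2|
+--------+-----------------+-----------------+
|       1|        [a, b, c]|        [u, v, w]|
|       2|        [c, d, e]|     [w, x, y, z]|
+--------+-----------------+-----------------+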

1 Answer


It could be inefficient to explode, but fundamentally the operation you are trying to implement is simply expensive. Effectively it is just another groupByKey, and there is not much you can do here to make it better. Since you use Spark > 2.0, you could collect_list directly and flatten:

import org.apache.spark.sql.functions.{collect_list, udf}

val flatten_distinct = udf(
  (xs: Seq[Seq[String]]) => xs.flatten.distinct)

df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")), 
    flatten_distinct(collect_list("array_value_2"))
  )

In Spark >= 2.4 you can replace the udf with a composition of built-in functions:

import org.apache.spark.sql.functions.{array_distinct, flatten}

val flatten_distinct = (array_distinct _) compose (flatten _)
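
For completeness, a minimal sketch of how this composed version slots into the same aggregation as above (the combined_values_* aliases are borrowed from the question, otherwise my own choice):

df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")).alias("combined_values_1"),
    flatten_distinct(collect_list("array_value_2")).alias("combined_values_2")
  )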

It is also possible to use a custom Aggregator, but I doubt any of these will make a huge difference.
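
For illustration only, a minimal sketch of the Aggregator route, assuming Spark 3.0+ (where functions.udaf lets an Aggregator be used from the untyped DataFrame API); the name mergeDistinct, the aliases, and the encoder choices are mine, not part of the original code:

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Buffer and result are both Seq[String]; duplicates are dropped while merging.
val mergeDistinct = new Aggregator[Seq[String], Seq[String], Seq[String]] {
  def zero: Seq[String] = Seq.empty
  def reduce(acc: Seq[String], xs: Seq[String]): Seq[String] = (acc ++ xs).distinct
  def merge(a: Seq[String], b: Seq[String]): Seq[String] = (a ++ b).distinct
  def finish(acc: Seq[String]): Seq[String] = acc
  def bufferEncoder: Encoder[Seq[String]] = ExpressionEncoder[Seq[String]]()
  def outputEncoder: Encoder[Seq[String]] = ExpressionEncoder[Seq[String]]()
}

df.groupBy("category").agg(
  udaf(mergeDistinct)(col("array_value_1")).alias("combined_values_1"),
  udaf(mergeDistinct)(col("array_value_2")).alias("combined_values_2"))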

If the sets are relatively large and you expect a significant number of duplicates, you could try aggregateByKey with mutable sets:

import scala.collection.mutable.{Set => MSet}

import org.apache.spark.sql.functions.struct
import sparkSession.implicits._  // needed for $, .as[...] and .toDF

val rdd = df
  .select($"category", struct($"array_value_1", $"array_value_2"))
  .as[(Int, (Seq[String], Seq[String]))]
  .rdd

val agg = rdd
  .aggregateByKey((MSet[String](), MSet[String]()))(
    { case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++= ys) },
    { case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++= accY2) }
  )
  .mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
  .toDF
  • The first solution of a simple flatten udf totally fixed the issue. Spark went from taking ~12 hours before failing out to completing the whole job successfully in 30 minutes. Watching the Spark monitor GUI, each of the internal tasks run and complete in a minute or less. Thanks for the help on this. – clay Sep 14 '16 at 20:24
  • I am glad to hear that although I have to admit I am surprised. I expected a small improvement but nothing so impressive. How large are individual lists? – zero323 Sep 14 '16 at 20:29
  • You saved me hours of searching... Thanks a lot! – Mike Reiche May 10 '20 at 11:22