
I have loaded a CSV file, repartitioned it to 4 partitions, and then taken a count of the DataFrame. When I looked at the DAG, I saw that this action was executed in 3 stages.


Why is this simple action executed in 3 stages? I suppose the 1st stage is to load the file and the 2nd is to find the count on each partition.

So what is happening in the 3rd stage?

Here is my code:

    val sample = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .option("delimiter", ";")
      .load("sample_data.csv")

    sample.repartition(4).count()

1 Answer

  1. The first stage = reading the file. Because `repartition` is a wide transformation that requires shuffling, it can't be combined into a single stage with `partial_count` (the 2nd stage).

  2. The second stage = the local count (calculating the count per partition).

  3. The third stage = aggregation of the results on the driver (see the sketch just below).
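
To make stages 2 and 3 concrete, here is a rough RDD-level sketch in the question's Scala, reusing `sample` from the question. This is conceptual, not the literal physical plan that Catalyst produces:

    // stage 2: one local count per partition (what partial_count does)
    val partialCounts = sample.repartition(4).rdd
      .mapPartitions(it => Iterator(it.size.toLong))

    // the merge the answer calls stage 3: the per-partition results
    // are aggregated on the driver into the final count
    val total = partialCounts.reduce(_ + _)

Each task in stage 2 emits a single Long, so only one number per partition travels to the final aggregation.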

Spark generates a separate stage per action or wide transformation. For more details about narrow/wide transformations and why wide transformations require a separate stage, take a look at "Wide Versus Narrow Dependencies" in High Performance Spark by Holden Karau, or this article.
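
You can see the difference directly in the plans. A small sketch in the question's Scala (the column name is hypothetical, just for illustration):

    import org.apache.spark.sql.functions.col

    // narrow: filter keeps the one-to-one partition dependency,
    // so it is pipelined into the same stage as the file scan
    sample.filter(col("some_column").isNotNull).explain()

    // wide: repartition shuffles every row, which shows up as an
    // Exchange RoundRobinPartitioning(4) and forces a stage boundary
    sample.repartition(4).explain()

A stage boundary in the plan is exactly where an Exchange appears.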

Let's test this assumption locally. First you need to create the dataset:

dataset/test-data.json

[
  { "key":  1, "value":  "a" },
  { "key":  2, "value":  "b" },
  { "key":  3, "value":  "c" },
  { "key":  4, "value":  "d" },
  { "key":  5, "value":  "e" },
  { "key":  6, "value":  "f" },
  { "key":  7, "value":  "g" },
  { "key":  8, "value":  "h" }
]

Then run the following code:

    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    StructType schema = new StructType()
            .add("key", DataTypes.IntegerType)
            .add("value", DataTypes.StringType);

    SparkSession session = SparkSession.builder()
            .appName("sandbox")
            .master("local[*]")
            .getOrCreate();

    session
            .read()
            .schema(schema)
            .json("file:///C:/<your_path>/dataset")
            .repartition(4) // comment this line out on the second run
            .registerTempTable("df");

    session.sqlContext().sql("SELECT COUNT(*) FROM df").explain();

The output will be:

== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
   +- *(2) HashAggregate(keys=[], functions=[partial_count(1)])
      +- Exchange RoundRobinPartitioning(4)
         +- *(1) FileScan json [] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:/C:/Users/iaroslav/IdeaProjects/sparksandbox/src/main/resources/dataset], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>

But if you comment out or remove the `.repartition(4)` line, note that FileScan and partial_count are done within a single stage, and the output will be as follows:

== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)])
+- Exchange SinglePartition
   +- *(1) HashAggregate(keys=[], functions=[partial_count(1)])
      +- *(1) FileScan json [] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:/C:/Users/iaroslav/IdeaProjects/sparksandbox/src/main/resources/dataset], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>

P.S. Note that the extra stage might have a significant impact on performance, since it requires disk I/O (take a look here) and acts as a kind of synchronization barrier that hurts parallelization, meaning in most cases Spark won't start stage 2 until stage 1 is completed. Still, if the repartition increases the level of parallelism, it is probably worth it.
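
A minimal sketch of that trade-off, again reusing `sample` from the question (the threshold of 4 is simply the question's target partition count):

    // check how many partitions the read actually produced
    // before paying for a shuffle
    val n = sample.rdd.getNumPartitions

    val total =
      if (n < 4) sample.repartition(4).count() // more parallelism, worth the extra stage
      else sample.count()                      // enough partitions already, skip the shuffle

If you only ever need fewer partitions, `coalesce` avoids the full shuffle altogether.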

  • Thanks for the answer. It makes sense. So this final result-aggregation task is always supposed to run on the driver? What if I take max() of a field in my DF? After the max value is found on each partition, is the final value found at the driver? – ѕтƒ Nov 07 '19 at 04:38
  • @ѕтƒ yes, `count` will always pull results to the driver. I think `max` may do the same, otherwise how could it compare results among partitions? But the point here, I think, is that you don't have to optimize the `count` operation, since only one integer per partition is moved to the driver. The heaviest operation here is `repartition`, and if you initially have more than 4 partitions, take a look at `coalesce`, or maybe don't repartition at all (if the number of partitions is reasonable). Say you have 2 partitions; `count` without repartitioning may be faster than `repartition(4).count()`. – VB_ Nov 07 '19 at 09:48
  • @ѕтƒ Speaking of further optimizations (supposing you have control over how the CSV file is created), you can look at different file formats (e.g. columnar Parquet; for a `count` operation it's enough to read a single column instead of the whole file). Also check compression options: if you want to proceed with CSV, then you need a splittable compression option. – VB_ Nov 07 '19 at 09:57
  • @ѕтƒ But this is all theory, while in practice the `count` operation is pretty lightweight and doesn't require those optimizations. Another interesting point is that some data formats contain metadata, and Spark also offers a Cost-Based Optimizer. That means that with some formats (not CSV) or with the CBO, for a `count` operation Spark won't read the data but only the metadata instead. In such a case `count` should be very fast, and `repartition` IMHO wouldn't be needed. – VB_ Nov 07 '19 at 10:01
  • Thanks for the valuable comments. I'm not looking for a performance improvement. All I want to learn is how Spark coordinates the results from different executors and aggregates them to show the final result, whether it's count, max, or some other action. I want to know who aggregates the data on the driver side after the driver gets the independent result from each partition. From your answer, it's another task which runs on the driver node. – ѕтƒ Nov 07 '19 at 12:28