
How can a DataFrame be partitioned based on the number of items in a column? Suppose we have a DataFrame with 100 people (the columns are first_name and country) and we'd like to create a partition for every 10 people in a country.

If our dataset contains 80 people from China, 15 people from France, and 5 people from Cuba, then we'll want 8 partitions for China, 2 partitions for France, and 1 partition for Cuba.

Here is code that will not work:

  • df.repartition($"country"): This will create one partition for China, one for France, and one for Cuba
  • df.repartition(8, $"country", rand): This will create up to 8 partitions for each country, so it should create 8 partitions for China, but the France and Cuba partitions are unknown. France could be in 8 partitions and Cuba could be in up to 5 partitions. See this answer for more details. (A small inspection sketch follows this list.)
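
For reference, here's a minimal sketch (against the hypothetical df above, with a country column) of how I'm inspecting where rows actually land, using spark_partition_id():

// Sketch: inspect how repartition($"country") distributes rows,
// by tagging each row with the ID of the partition it ended up in.
import org.apache.spark.sql.functions.{col, spark_partition_id}

df.repartition(col("country"))
  .withColumn("pid", spark_partition_id())
  .groupBy("country", "pid")
  .count()
  .show()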

Here's the repartition() documentation: [screenshot of the repartition Scaladoc]

When I look at the repartition() method, I don't even see a method that takes three arguments, so it looks like some of this behavior isn't documented.

Is there any way to dynamically set the number of partitions for each column value? It would make creating partitioned data sets way easier.

– Powers

2 Answers


You're not going to be able to accomplish that exactly, due to the way Spark partitions data. Spark takes the columns you specified in repartition, hashes each row's values, and takes that hash modulo the number of partitions, so the assignment of rows to partitions is deterministic. The reason it works this way is that joins need a matching number of partitions on the left and right sides of the join, in addition to assuring that the hashing is the same on both sides.
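
For intuition, here's a minimal sketch (using the question's hypothetical df) that reproduces that hash-then-modulo assignment with Spark's built-in hash and pmod functions:

// Every row with the same country gets the same hash, hence the same
// partition number -- but distinct countries can collide into one partition.
import org.apache.spark.sql.functions.{col, hash, pmod, lit}

val numPartitions = 8

df.withColumn("assigned_partition", pmod(hash(col("country")), lit(numPartitions)))
  .groupBy("country", "assigned_partition")
  .count()
  .show()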

"we'd like to create a partition for every 10 people in a country."

What exactly are you trying to accomplish here? Having only 10 rows in a partition is likely terrible for performance. Are you trying to create a partitioned table where each of the files in a partition is guaranteed to have only x rows?

"df.repartition($"country"): This will create 1 partition for China, one partition for France, and one partition for Cuba"

This will actually create a DataFrame with the default number of shuffle partitions (spark.sql.shuffle.partitions, 200 by default), hashed by country:

  def repartition(partitionExprs: Column*): Dataset[T] = {
    repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
  }

"df.repartition(8, $"country", rand): This will create up to 8 partitions for each country, so it should create 8 partitions for China, but the France & Cuba partitions are unknown. France could be in 8 partitions and Cuba could be in up to 5 partitions. See this answer for more details."

Likewise, this is subtly wrong. There are only 8 partitions in total, with the countries essentially randomly shuffled among those 8 partitions.
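
A quick way to check this (a sketch against the question's hypothetical df):

import org.apache.spark.sql.functions.{col, rand, spark_partition_id, countDistinct}

val repartitioned = df.repartition(8, col("country"), rand())

// Exactly 8 partitions in total, no matter how many countries there are.
println(repartitioned.rdd.getNumPartitions)  // 8

// Each country's rows are typically spread over several of those 8 partitions.
repartitioned
  .withColumn("pid", spark_partition_id())
  .groupBy("country")
  .agg(countDistinct("pid").as("partitions_touched"))
  .show()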

– Andrew Long
  • Thanks for pointing out my subtle errors. For 10 rows, this code isn't needed, but this is really important when creating partitioned data lakes on large datasets that are skewed. – Powers Oct 19 '19 at 19:06
  • @AndrewLong there is no "sessionState" on SparkSession; where does "sessionState" come from? – Shasu Oct 05 '22 at 02:06
  • @Shasu are you using an older version of Spark, 1.x? – Andrew Long Mar 15 '23 at 20:36
  • @AndrewLong no, I am using Spark 2.4.5 and now 3.3.1 – Shasu Apr 15 '23 at 16:59
  • Spark has session state that is private to Spark: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala – Andrew Long Jun 21 '23 at 23:06

Here's the code (Spark 2.2+) that'll create ten rows per data file (the sample dataset is here):

import org.apache.spark.sql.functions.col

val outputPath = new java.io.File("./tmp/partitioned_lake5/").getCanonicalPath
df
  .repartition(col("person_country"))   // one shuffle partition (and write task) per country
  .write
  .option("maxRecordsPerFile", 10)      // cap every output file at 10 records
  .partitionBy("person_country")
  .csv(outputPath)
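
maxRecordsPerFile (added in Spark 2.2) makes the writer roll over to a new file once a file hits 10 records, so with the question's data the person_country=China directory ends up with 8 files, France with 2, and Cuba with 1, with no manual partition-count math.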

Here's the pre-Spark 2.2 code that'll create roughly ten rows per data file:

import org.apache.spark.sql.functions.{col, rand}
import org.apache.spark.sql.types.IntegerType

val desiredRowsPerPartition = 10

// countDF holds the row count per country (the original snippet assumes it
// already exists); it just needs a "count" column keyed by person_country.
val countDF = df.groupBy("person_country").count()

val joinedDF = df
  .join(countDF, Seq("person_country"))
  .withColumn(
    "my_secret_partition_key",
    (rand(10) * col("count") / desiredRowsPerPartition).cast(IntegerType)
  )

val outputPath = new java.io.File("./tmp/partitioned_lake6/").getCanonicalPath
joinedDF
  .repartition(col("person_country"), col("my_secret_partition_key"))
  .drop("count", "my_secret_partition_key")
  .write
  .partitionBy("person_country")
  .csv(outputPath)
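
For each country, rand(10) * count / desiredRowsPerPartition falls in [0, count / 10), so the integer cast yields roughly count / 10 distinct values of my_secret_partition_key: about 8 for China, up to 2 for France, and 1 for Cuba, which is what repartitioning on (person_country, my_secret_partition_key) needs to approximate ten rows per file.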
– Powers (edited by Machavity)