25

I want to know how to get information about each partition, such as the total number of records in each partition, on the driver side when a Spark job is submitted with deploy mode yarn-cluster, so that I can log it or print it on the console.

zero323
nilesh1212

6 Answers

35

I'd use a built-in function. It should be as efficient as it gets:

import org.apache.spark.sql.functions.spark_partition_id

df.groupBy(spark_partition_id()).count()
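
If the goal (as in the question) is to log these per-partition counts on the driver in yarn-cluster mode, a minimal sketch, assuming an existing DataFrame df (the log message format is illustrative), is to collect the small aggregated result and print it; the output then ends up in the driver's YARN container log rather than the submitting console:

import org.apache.spark.sql.functions.spark_partition_id

// One row per partition, so collecting to the driver is cheap.
val counts = df.groupBy(spark_partition_id().as("partition_id")).count().collect()

counts.foreach { row =>
  println(s"partition ${row.getInt(0)} has ${row.getLong(1)} records")
}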
Alper t. Turker
27

You can get the number of records per partition like this:

df
  .rdd
  .mapPartitionsWithIndex{case (i,rows) => Iterator((i,rows.size))}
  .toDF("partition_number","number_of_records")
  .show

But this will also launch a Spark job by itself, because the file must be read by Spark to get the number of records.

Spark may also be able to read Hive table statistics, but I don't know how to display that metadata.
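
For what it's worth, a hedged sketch of surfacing table-level statistics through Spark SQL (assuming a Hive-backed table named my_table, Hive support enabled, and a reasonably recent Spark version; note these are table statistics, not per-RDD-partition counts):

spark.sql("ANALYZE TABLE my_table COMPUTE STATISTICS")          // computes row count and size in bytes
spark.sql("DESCRIBE EXTENDED my_table").show(truncate = false)  // statistics appear in the output once computed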

Raphael Roth
26

For future PySpark users:

from pyspark.sql.functions import spark_partition_id
rawDf.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()
BishoyM
5

Spark/Scala:

val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
val l = a.glom().map(_.length).collect()  // get length of each partition
println((l.min, l.max, l.sum / l.length, l.length))  // check if skewed

PySpark:

num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect()  # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l))  # check if skewed

The same is possible for a DataFrame, not just for an RDD. Just use DF.rdd.glom... in the code above (a sketch follows below).

Credits: Mike Dusenberry @ https://issues.apache.org/jira/browse/SPARK-17817
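
A hedged Scala sketch of the DataFrame variant mentioned above, assuming an existing DataFrame df:

// Go through .rdd and glom() to get each partition's rows, then take the lengths.
val partitionSizes = df.rdd.glom().map(_.length).collect()
println((partitionSizes.min, partitionSizes.max,
  partitionSizes.sum / partitionSizes.length, partitionSizes.length))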

Tagar
3

Spark 1.5 solution:

(sparkPartitionId() exists in org.apache.spark.sql.functions)

import org.apache.spark.sql.functions._ 

df.withColumn("partitionId", sparkPartitionId()).groupBy("partitionId").count.show

As mentioned by @Raphael Roth,

mapPartitionsWithIndex is the best approach; it will work with every version of Spark, since it is an RDD-based approach.

Praveen Sripati
Ram Ghadiyaram
0

PySpark:

from pyspark.sql.functions import spark_partition_id

df.select(spark_partition_id().alias("partitionId")).groupBy("partitionId").count().show()
rwitzel