I want to know how we can get information about each partition, such as the total number of records in each partition, on the driver side when a Spark job is submitted with deploy mode yarn-cluster, so that I can log or print it on the console.
6 Answers
I'd use a built-in function. It should be as efficient as it gets:
import org.apache.spark.sql.functions.spark_partition_id
df.groupBy(spark_partition_id).count
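To actually print the result on the driver, you can add an ordering and a `show`; a minimal sketch, assuming `df` is the DataFrame in question:

import org.apache.spark.sql.functions.spark_partition_id

// Rows per Spark partition, ordered by partition id, printed on the driver.
df.groupBy(spark_partition_id().alias("partition_id"))
  .count()
  .orderBy("partition_id")
  .show()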

You can use `df.withColumn("partition_id", spark_partition_id).groupBy("partition_id").count` for Spark 1.6 – philantrovert Sep 04 '17 at 12:24
You can get the number of records per partition like this:
df
.rdd
.mapPartitionsWithIndex{case (i,rows) => Iterator((i,rows.size))}
.toDF("partition_number","number_of_records")
.show
But this will also launch a Spark job by itself (because the file must be read by Spark to get the number of records).
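In yarn-cluster mode the output of `show` ends up in the driver container's log rather than on the submitting console; if you would rather hand the numbers to a logger yourself, here is a minimal sketch along the same lines (the per-partition result is tiny and easily fits in driver memory):

// Compute (partitionIndex, recordCount) pairs and collect them to the driver,
// where they can be passed to whatever logging framework is in use.
val partitionCounts: Array[(Int, Int)] = df.rdd
  .mapPartitionsWithIndex { case (i, rows) => Iterator((i, rows.size)) }
  .collect()

partitionCounts.foreach { case (partition, count) =>
  println(s"partition $partition contains $count records") // or log.info(...)
}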
Spark may also be able to read Hive table statistics, but I don't know how to display that metadata.
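For what it's worth, table-level statistics can be computed and displayed through SQL; a minimal sketch, assuming a Spark 2.x SparkSession named spark and a Hive-backed table named my_table (note that Hive partitions are not the same thing as Spark's RDD partitions):

// Compute table statistics, then display the detailed table information,
// which includes the statistics that were gathered (assumes table my_table).
spark.sql("ANALYZE TABLE my_table COMPUTE STATISTICS")
spark.sql("DESCRIBE EXTENDED my_table").show(truncate = false)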

For future PySpark users:
from pyspark.sql.functions import spark_partition_id
rawDf.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()

Spark/Scala:
val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
val l = a.glom().map(_.length).collect() // get length of each partition
println((l.min, l.max, l.sum / l.length, l.length)) // check if skewed
PySpark:
num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect() # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l)) # check if skewed
The same is possible for a DataFrame, not just for an RDD. Just add DF.rdd.glom... into the code above (see the sketch below).
Credits: Mike Dusenberry @ https://issues.apache.org/jira/browse/SPARK-17817
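A minimal sketch of that DataFrame variant, assuming a DataFrame named df:

// Same skew check as above, but starting from a DataFrame.
val sizes = df.rdd.glom().map(_.length).collect()
println((sizes.min, sizes.max, sizes.sum / sizes.length, sizes.length)) // min, max, average (integer division), number of partitions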

Spark 1.5 solution (sparkPartitionId() exists in org.apache.spark.sql.functions):
import org.apache.spark.sql.functions._
df.withColumn("partitionId", sparkPartitionId()).groupBy("partitionId").count.show
As mentioned by @Raphael Roth, mapPartitionsWithIndex is the best approach; it will work with all versions of Spark since it is an RDD-based approach.

PySpark:
from pyspark.sql.functions import spark_partition_id
df.select(spark_partition_id().alias("partitionId")).groupBy("partitionId").count()
