
I'm looking for the PySpark equivalent to this question: How to get the number of elements in partition?

Specifically, I want to programmatically count the number of elements in each partition of a PySpark RDD or DataFrame (I know this information is available in the Spark Web UI).

This attempt:

df.foreachPartition(lambda iter: sum(1 for _ in iter))

results in:

AttributeError: 'NoneType' object has no attribute '_jvm'

I do not want to collect the contents of the iterator into memory.


1 Answer


If you are asking whether we can get the number of elements in an iterator without iterating through it, the answer is no.

But we don't have to store the partition's contents in memory to count them, as in the post you mentioned:

def count_in_a_partition(idx, iterator):
  # Consume the iterator once, keeping only a running count in memory.
  count = 0
  for _ in iterator:
    count += 1
  # mapPartitionsWithIndex expects the function to return an iterator,
  # so yield a single (partition index, count) pair rather than returning a tuple.
  yield idx, count

# Four elements split across four partitions, one element each
data = sc.parallelize([
    1, 2, 3, 4
], 4)

data.mapPartitionsWithIndex(count_in_a_partition).collect()
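
With four elements spread over four partitions, each partition holds exactly one element, so (with the yield version above) collect() should return:

[(0, 1), (1, 1), (2, 1), (3, 1)]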

EDIT

Note that your code is very close to the solution: foreachPartition is an action that returns None (so whatever your function computes is discarded), while mapPartitions is a transformation whose function needs to return an iterator:

def count_in_a_partition(iterator):
  # yield (rather than return) makes this a generator, which is
  # exactly the iterator that mapPartitions expects
  yield sum(1 for _ in iterator)

data.mapPartitions(count_in_a_partition).collect()
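
For the DataFrame case the question also mentions, here is a minimal sketch, assuming you have a DataFrame named df and are on Spark 1.6+ (where pyspark.sql.functions.spark_partition_id is available):

from pyspark.sql.functions import spark_partition_id

# Tag each row with the ID of the partition it lives in, then count rows per
# partition. The counting stays distributed; only the small summary is returned.
df.groupBy(spark_partition_id().alias("partition_id")).count().show()

# Or reuse the RDD-based approach on the DataFrame's underlying RDD:
df.rdd.mapPartitions(count_in_a_partition).collect()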
  • Thanks @ShuaiYuan. No, I know that I'll have to iterate through to get the count. Your first solution works for me! However, the second still throws the same AttributeError as my original attempt in Spark 1.5.0 (my organization's cluster), even on the "data" rdd you create in your example. AttributeError: 'NoneType' object has no attribute '_jvm'. However, in Spark Community Edition running 1.6.0 or 1.5.2, both of your solutions work. Perhaps something strange about my local CDH distro? – Matt Frei Aug 15 '16 at 22:50
  • Could be. Unfortunately I don't have a Spark 1.5.0 to test with. – shuaiyuancn Aug 16 '16 at 10:46