
I need to collect partitions/batches from a big PySpark DataFrame so that I can feed them into a neural network iteratively.

My idea was to 1) partition the data, 2) iteratively collect each partition, and 3) transform each collected partition with toPandas().

I am a bit confused by methods like foreachPartition and mapPartitions because I can't iterate over their results on the driver. Any ideas?
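
Roughly, this is what I have in mind (a sketch only: `df` stands for my big DataFrame, and the partition count is made up):

from pyspark.sql.functions import spark_partition_id

num_parts = 8
# Tag every row with the id of the partition it lives in.
df = df.repartition(num_parts).withColumn("pid", spark_partition_id())
for pid in range(num_parts):
    # Pull one partition at a time to the driver as a pandas DataFrame.
    pandas_batch = df.filter(df.pid == pid).drop("pid").toPandas()
    # feed pandas_batch into the neural network here

This works but rescans the whole DataFrame on every iteration, which is why I was looking at foreachPartition / mapPartitions.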

cadama

1 Answer


You can use mapPartitions to map each partition into a single list of elements, then fetch those lists one at a time with toLocalIterator:

# Each partition is collapsed into one Python list, so toLocalIterator
# yields one whole partition per iteration (one batch in driver memory
# at a time):
for partition in rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
    print(len(partition))  # or do something else :-)
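
For the DataFrame case in the question, an end-to-end version might look like the sketch below; the toy DataFrame, column name, and partition count are made up for illustration:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Stand-in for the big DataFrame; repartition() controls the batch count.
df = spark.range(0, 100000).withColumnRenamed("id", "feature").repartition(10)

# Each partition arrives as a single list of Rows, which we turn into
# a pandas DataFrame before handing it to the network.
for rows in df.rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
    batch = pd.DataFrame([r.asDict() for r in rows])
    # feed `batch` (a pandas DataFrame) into the neural network here
    print(len(batch))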
Mariusz