Here's an example of PySpark's `flatMap` on an RDD:

```python
sc.parallelize([3, 4, 5]).flatMap(lambda x: range(1, x)).collect()
```

which yields `[1, 2, 1, 2, 3, 1, 2, 3, 4]`,
as opposed to plain `map`, which would yield the nested `[[1, 2], [1, 2, 3], [1, 2, 3, 4]]` (for comparison).
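For a side-by-side check, here's the `map` version (wrapping each `range` in `list` so the nested output prints as lists on Python 3):

```python
# map produces exactly one output element per input element,
# so the inner sequences stay nested instead of being unrolled
sc.parallelize([3, 4, 5]).map(lambda x: list(range(1, x))).collect()
# [[1, 2], [1, 2, 3], [1, 2, 3, 4]]
```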
`flatMap` also only does one level of "unnesting". In other words, if you have a 3-d list, it will only flatten it to a 2-d list. So, we'll make our flattener do this too.
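As a quick sketch of that one-level behavior, using the identity function to just pass each sublist through:

```python
# flatMap with the identity function strips exactly one level of nesting:
# the 3-d input comes back 2-d, not fully flat
sc.parallelize([[[1, 2], [3]], [[4], [5, 6]]]).flatMap(lambda x: x).collect()
# [[1, 2], [3], [4], [5, 6]]
```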
As alluded to in the comments, all you have to do is call the built-in `map`, create a flattening function, and chain them together. Here's how:
```python
def flatMap(f, li):
    # map f over the iterable, then flatten the result by one level
    mapped = map(f, li)
    flattened = flatten_single_dim(mapped)
    yield from flattened

def flatten_single_dim(mapped):
    # unroll each item by exactly one level of nesting
    for item in mapped:
        for subitem in item:
            yield subitem
```
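For what it's worth, the standard library already has that flattening step: `itertools.chain.from_iterable` does exactly the single-level unroll, so an equivalent sketch is:

```python
from itertools import chain

def flatMap(f, li):
    # chain.from_iterable performs the same single-level flatten
    # as flatten_single_dim above
    return chain.from_iterable(map(f, li))
```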
Going back to our example as a quick sanity check:

```python
res = flatMap(lambda x: range(1, x), [3, 4, 5])
print(list(res))
```

which outputs `[1, 2, 1, 2, 3, 1, 2, 3, 4]`,
as desired. For your case, you'd do `flatMap(lambda tile: process_tile(tile, sample_size, grayscale), filtered_tiles)` (given `filtered_tiles` is an iterable).
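To make that concrete, here's a hypothetical sketch; `process_tile`, `sample_size`, `grayscale`, and `filtered_tiles` are stand-ins for your actual code, and I'm assuming `process_tile` returns an iterable of results per tile (which is what `flatMap` needs):

```python
def process_tile(tile, sample_size, grayscale):
    # hypothetical stand-in: your real function would do the
    # per-tile work and return an iterable of results
    return [(tile, i, grayscale) for i in range(sample_size)]

filtered_tiles = ["tile_a", "tile_b"]  # placeholder input
results = flatMap(lambda tile: process_tile(tile, 2, True), filtered_tiles)
print(list(results))
# [('tile_a', 0, True), ('tile_a', 1, True), ('tile_b', 0, True), ('tile_b', 1, True)]
```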
P.S. As a side note, you can run Spark in "local" mode and just call `flatMap` on RDDs. It'll work just fine for prototyping small stuff on your local machine; then you can hook into a cluster with some cluster manager when you're ready to scale and have TBs of data you need to rip through.
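For instance, a minimal local-mode setup (assuming you have `pyspark` installed; the app name is arbitrary):

```python
from pyspark import SparkContext

# "local[*]" runs Spark in-process, using all available cores
sc = SparkContext("local[*]", "prototyping")
print(sc.parallelize([3, 4, 5]).flatMap(lambda x: range(1, x)).collect())
sc.stop()
```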
HTH.