It is possible to handle this without shuffling (groupBy), but it requires a bit more code than the solutions by Olologin and Rohan Aletty. The whole idea is to transfer only the parts required to keep continuity between partitions:
from functools import reduce  # reduce is not a builtin on Python 3
from toolz import partition, drop, take, concatv


def grouped(self, n, pad=None):
    """
    Group RDD into tuples of size n

    >>> rdd = sc.parallelize(range(10))
    >>> grouped(rdd, 3).collect()
    [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, None, None)]
    """
    assert isinstance(n, int)
    assert n > 0

    def _analyze(i, iter):
        """
        Given partition idx and iterator return a tuple
        (idx, number-of-elements, prefix-of-size-(n - 1))
        """
        xs = [x for x in iter]
        return [(i, len(xs), xs[:n - 1])]

    def _compact(prefixes, prefix):
        """
        'Compact' a list of prefixes to compensate for
        partitions with fewer than (n - 1) elements
        """
        return prefixes + [(prefix + prefixes[-1])[:n - 1]]

    def _compute(prvs, cnt):
        """
        Compute the number of elements to drop from the current and
        take from the next partition given the previous state
        """
        left_to_drop, _to_drop, _to_take = prvs[-1]
        diff = cnt - left_to_drop
        if diff <= 0:
            # Current partition is consumed entirely by the previous one
            return prvs + [(-diff, cnt, 0)]
        else:
            to_take = (n - diff % n) % n
            return prvs + [(to_take, left_to_drop, to_take)]

    def _group_partition(i, iter):
        """
        Return grouped entries for a given partition
        """
        (_, to_drop, to_take), next_head = heads_bd.value[i]
        return partition(n, concatv(
            drop(to_drop, iter), take(to_take, next_head)), pad=pad)

    if n == 1:
        return self.map(lambda x: (x, ))

    # Collect per-partition metadata: index, size and (n - 1)-element prefix
    idxs, counts, prefixes = zip(
        *self.mapPartitionsWithIndex(_analyze).collect())

    # For each partition broadcast how many elements to drop / take
    # and the head of the following partition
    heads_bd = self.context.broadcast({x[0]: (x[1], x[2]) for x in zip(
        idxs,
        reduce(_compute, counts, [(0, None, None)])[1:],
        reduce(_compact, prefixes[::-1], [[]])[::-1][1:])})

    return self.mapPartitionsWithIndex(_group_partition)
It depends heavily on the great toolz library, but if you prefer to avoid external dependencies you can easily rewrite it using the standard library, for example as sketched below.
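Here is a minimal, hypothetical replacement for the four toolz helpers used above, built only on itertools. It is not a full re-implementation of toolz semantics (in particular, this partition always pads the last group, which is all grouped needs here):

from itertools import chain, islice


def drop(n, seq):
    """Skip the first n elements of seq."""
    return islice(seq, n, None)


def take(n, seq):
    """Take at most the first n elements of seq."""
    return islice(seq, n)


def concatv(*seqs):
    """Lazily concatenate the given sequences."""
    return chain(*seqs)


def partition(n, seq, pad=None):
    """Yield n-tuples from seq, padding the last one with pad."""
    it = iter(seq)
    while True:
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        if len(chunk) < n:
            chunk += (pad, ) * (n - len(chunk))
        yield chunk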
Example usage:
>>> rdd = sc.parallelize(range(10))
>>> grouped(rdd, 3).collect()
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, None, None)]
If you want to keep a consistent API you can monkey-patch the RDD class:
>>> from pyspark.rdd import RDD
>>> RDD.grouped = grouped
>>> rdd.grouped(4).collect()
[(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, None, None)]
You can find basic tests on GitHub.
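For a quick local sanity check (a hypothetical sketch, not the tests linked above), assertions along these lines should hold, assuming an active SparkContext sc:

# Hypothetical sanity checks, assuming an active SparkContext `sc`
rdd = sc.parallelize(range(10), 3)

# Groups span partition boundaries and the last group is padded
assert grouped(rdd, 3).collect() == [
    (0, 1, 2), (3, 4, 5), (6, 7, 8), (9, None, None)]

# A custom pad value is respected
assert grouped(rdd, 4, pad=-1).collect() == [
    (0, 1, 2, 3), (4, 5, 6, 7), (8, 9, -1, -1)]

# n == 1 simply wraps every element in a 1-tuple
assert grouped(rdd, 1).collect() == [(x, ) for x in range(10)]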