As explained in the link by cricket_007, when using combineByKey,
values are first merged into one value per key within each partition, and then those per-partition values are merged into a single value per key.
Let's first look at the number of partitions and what each partition contains after we parallelize the data.
>>> data = [ ("B", 2), ("A", 1), ("A", 4), ("B", 2), ("B", 3) ]
>>> rdd = sc.parallelize( data )
>>> rdd.collect()
[('B', 2), ('A', 1), ('A', 4), ('B', 2), ('B', 3)]
Number of partitions (by default):
>>> num_partitions = rdd.getNumPartitions()
>>> print(num_partitions)
4
Contents of each partition:
>>> partitions = rdd.glom().collect()
>>> for num, partition in enumerate(partitions):
...     print(f'Partition {num} -> {partition}')
Partition 0 -> [('B', 2)]
Partition 1 -> [('A', 1)]
Partition 2 -> [('A', 4)]
Partition 3 -> [('B', 2), ('B', 3)]
combineByKey is defined as
combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner)
The three functions that combineByKey takes as arguments are:
- createCombiner : lambda value: (value, value + 2, 1)
  This is called the first time a key is seen in a partition.
- mergeValue : lambda x, value: (x[0] + value, x[1] + value*value, x[2] + 1)
  This is called when the key has already been seen in that partition.
- mergeCombiners : lambda x, y: (x[0] + y[0], x[1] + y[1], x[2] + y[2])
  This is called to merge a key's per-partition accumulators across partitions.
- partitioner : Beyond the scope of this answer.
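If it helps, here is a rough plain-Python sketch (not Spark's actual implementation) of those two phases, using the same three lambdas and the same partition layout as the glom() output above. You can follow the same steps by hand below.
createCombiner = lambda value: (value, value + 2, 1)
mergeValue     = lambda x, value: (x[0] + value, x[1] + value*value, x[2] + 1)
mergeCombiners = lambda x, y: (x[0] + y[0], x[1] + y[1], x[2] + y[2])

partitions = [[('B', 2)], [('A', 1)], [('A', 4)], [('B', 2), ('B', 3)]]

# Phase 1: combine the values within each partition
per_partition = []
for partition in partitions:
    acc = {}
    for key, value in partition:
        if key not in acc:                 # unseen key -> createCombiner
            acc[key] = createCombiner(value)
        else:                              # seen key   -> mergeValue
            acc[key] = mergeValue(acc[key], value)
    per_partition.append(acc)

# Phase 2: merge the per-partition accumulators with mergeCombiners
result = {}
for acc in per_partition:
    for key, combined in acc.items():
        result[key] = mergeCombiners(result[key], combined) if key in result else combined

print(result)   # {'B': (7, 17, 3), 'A': (5, 9, 2)}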
Now let's work out what happens:
Partition 0: [('B', 2)]
createCombiner
('B', 2) -> Unseen Key -> ('B', (2, 2+2, 1))
                       -> ('B', (2, 4, 1))
# Same createCombiner for partitions 1, 2 and 3
Partition 1: [('A',1)]
createCombiner
('A',1) -> Unseen Key -> ('A', (1,3,1))
Partition 2: [('A',4)]
createCombiner
('A',4) -> Unseen Key -> ('A', (4,6,1))
Partition 3: [('B',2), ('B',3)]
createCombiner
('B',2) -> Unseen Key -> ('B',(2,4,1))
('B',3) -> Seen Key -> mergeValue ('B',(2,4,1)) with ('B',3)
-> ('B', (2+3, 4+(3*3), 1+1))
-> ('B', (5, 13, 2))
Partition 0 and Partition 3:
mergeCombiners ('B', (2,4,1)) and ('B', (5,13,2))
-> ('B', (2+5,4+13,1+2))
-> ('B', (7, 17, 3))
Partition 1 and 2:
mergeCombiners ('A', (1,3,1)) and ('A', (4,6,1))
-> ('A', (1+4, 3+6, 1+1))
-> ('A', (5,9,2))
So the final answer that we get is:
>>> rdd2 = rdd.combineByKey(lambda value: (value, value+2, 1),
... lambda x, value: (x[0] + value, x[1] + value*value, x[2] + 1),
... lambda x, y: (x[0] + y[0], x[1] + y[1], x[2] + y[2]))
>>> rdd2.collect()
[('B', (7, 17, 3)), ('A', (5, 9, 2))]
I hope this explains what's going on.
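As a side note, a more typical accumulator for those three functions is a (sum, count) pair, e.g. to compute a per-key average. A small sketch, reusing the same rdd from above:
>>> sum_count = rdd.combineByKey(lambda v: (v, 1),
...                              lambda acc, v: (acc[0] + v, acc[1] + 1),
...                              lambda a, b: (a[0] + b[0], a[1] + b[1]))
>>> sum_count.mapValues(lambda t: t[0] / t[1]).collect()
[('B', 2.3333333333333335), ('A', 2.5)]
Here 'B' averages (2+2+3)/3 and 'A' averages (1+4)/2.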
Additional clarification, as asked in the comments:
How does Spark set the number of partitions?
From the docs: Spark tries to set the number of partitions automatically based on your cluster. However, you can also set it manually by passing it as a second parameter to parallelize (e.g. sc.parallelize(data, 10)).
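For example, with the same data as above:
>>> sc.parallelize(data, 10).getNumPartitions()
10
>>> sc.parallelize(data, 2).getNumPartitions()
2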
How does Spark partition the data?
A partition (aka split) is a logical chunk of a large distributed data set.
Spark has three different partitioning schemes, namely
- hashPartitioner : The default. Keys with the same hash (modulo the number of partitions) end up in the same partition.
- customPartitioner : Example below.
- rangePartitioner : Elements with keys in the same range end up in the same partition.
To paraphrase Learning Spark by Karau et al. (p. 61): Spark does not give you explicit control over which key goes to which partition, but it ensures that a set of keys will appear together on some node. If you want all pairs with the same key to end up in the same partition, you can use a custom partitioner like so:
>>> def customPartitioner(key):
...     if key == 'A':
...         return 0    # every ('A', ...) pair goes to partition 0
...     if key == 'B':
...         return 1    # every ('B', ...) pair goes to partition 1
>>> num_partitions = 2
>>> rdd = sc.parallelize(data).partitionBy(num_partitions, customPartitioner)
>>> partitions = rdd.glom().collect()
>>> for num, partition in enumerate(partitions):
...     print(f'Partition {num} -> {partition}')
Partition 0 -> [('A', 1), ('A', 4)]
Partition 1 -> [('B', 2), ('B', 2), ('B', 3)]
I encourage you to read the book to learn more.