It turned out that Spark does have a way of controlling the partitioning logic exactly, and that is the `predicates` option of `spark.read.jdbc`.
What I eventually came up with is as follows:
(For the sake of the example, imagine that we have the purchase records of a store, and we need to partition them based on `userId` and `productId` so that all the records of an entity are kept together on the same machine, and we can perform aggregations on these entities without shuffling.)
- First, produce a histogram of every column that you want to partition by (the count of each value); see the sketch after the tables below:
| userId | count |
|--------|-------|
| 123456 | 1640  |
| 789012 | 932   |
| 345678 | 1849  |
| 901234 | 11    |
| ...    | ...   |

| productId | count |
|-----------|-------|
| 123456789 | 5435  |
| 523485447 | 254   |
| 363478326 | 2343  |
| 326484642 | 905   |
| ...       | ...   |
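For instance, the `userId` histogram can be computed by pushing a `GROUP BY` down to the database. A minimal sketch, assuming the same Oracle connection (`url`, driver) as the read snippet further down, and that the histogram itself is small enough to collect to the driver:

```python
# Count of records per USER_ID, computed by the database itself.
user_hist_query = '''
(SELECT USER_ID, COUNT(*) AS CNT
 FROM MY_TABLE
 GROUP BY USER_ID) USER_HIST
'''

user_hist = spark.read \
    .option('driver', 'oracle.jdbc.driver.OracleDriver') \
    .jdbc(url=url, table=user_hist_query) \
    .collect()

# The productId histogram is produced the same way, grouping by PRODUCT_ID.
```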
- Then, use the multifit algorithm to divide the values of each column into `n` balanced bins, `n` being the number of partitions that you want (a simplified sketch follows the tables below):
| userId | bin |
|--------|-----|
| 123456 | 1   |
| 789012 | 1   |
| 345678 | 1   |
| 901234 | 2   |
| ...    | ... |

| productId | bin |
|-----------|-----|
| 123456789 | 1   |
| 523485447 | 2   |
| 363478326 | 2   |
| 326484642 | 3   |
| ...       | ... |
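The exact bin-packing method is up to you; as a simpler stand-in for multifit, here is a greedy longest-first sketch (the largest counts are assigned first, each to the currently lightest bin). It assumes `user_hist` from the sketch above and that `n` holds the desired number of partitions:

```python
import heapq

def assign_bins(counts, n_bins):
    # counts: iterable of (value, count) pairs; returns {value: bin_number}.
    # Greedy: take the largest counts first and drop each into the lightest bin.
    bins = [(0, b) for b in range(1, n_bins + 1)]  # (total_count, bin_number)
    heapq.heapify(bins)
    assignment = {}
    for value, cnt in sorted(counts, key=lambda vc: vc[1], reverse=True):
        total, b = heapq.heappop(bins)
        assignment[value] = b
        heapq.heappush(bins, (total + cnt, b))
    return assignment

user_bins = assign_bins([(r['USER_ID'], r['CNT']) for r in user_hist], n)
# product_bins is built the same way from the productId histogram.
```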
- Then, store these bin assignments in the database, e.g. as `USER_PARTITION` and `PRODUCT_PARTITION` tables; a write sketch follows below.
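A sketch of that write-back, assuming the `user_bins` mapping from the previous sketch and the same JDBC `url` as in the read snippet below (the `PRODUCT_PARTITION` table is written the same way):

```python
from pyspark.sql import Row

# Turn the {USER_ID: bin} mapping into a small two-column lookup table.
user_partition_df = spark.createDataFrame(
    [Row(USER_ID=user_id, BIN=b) for user_id, b in user_bins.items()]
)

# Write it to the database as USER_PARTITION, replacing any previous run.
user_partition_df.write \
    .option('driver', 'oracle.jdbc.driver.OracleDriver') \
    .jdbc(url=url, table='USER_PARTITION', mode='overwrite')
```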
- Then, update your query to join on these tables so that you get the bin numbers for every record:
```python
url = 'jdbc:oracle:thin:username/password@address:port:dbname'

query = '''
(SELECT
    MY_TABLE.*,
    USER_PARTITION.BIN AS USER_BIN,
    PRODUCT_PARTITION.BIN AS PRODUCT_BIN
FROM MY_TABLE
LEFT JOIN USER_PARTITION
    ON MY_TABLE.USER_ID = USER_PARTITION.USER_ID
LEFT JOIN PRODUCT_PARTITION
    ON MY_TABLE.PRODUCT_ID = PRODUCT_PARTITION.PRODUCT_ID) MY_QUERY
'''

df = spark.read \
    .option('driver', 'oracle.jdbc.driver.OracleDriver') \
    .jdbc(url=url, table=query, predicates=predicates)
```
- And finally, generate the predicates (the `predicates` list passed to `spark.read.jdbc` above), one for each partition, like these:
```python
predicates = [
    'USER_BIN = 1 OR PRODUCT_BIN = 1',
    'USER_BIN = 2 OR PRODUCT_BIN = 2',
    'USER_BIN = 3 OR PRODUCT_BIN = 3',
    ...
    'USER_BIN = n OR PRODUCT_BIN = n',
]
```
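Rather than typing them out, the list can of course be generated; a one-line sketch, assuming `n` is the number of partitions chosen earlier:

```python
# One predicate per bin, for bins 1..n.
predicates = [f'USER_BIN = {b} OR PRODUCT_BIN = {b}' for b in range(1, n + 1)]
```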
Each predicate is added to the query as a `WHERE` clause, so Spark issues one query per partition (roughly `SELECT ... FROM (MY_QUERY) WHERE USER_BIN = 1 OR PRODUCT_BIN = 1` for the first one). This means that all the records of the users in partition 1 go to the same machine, and all the records of the products in partition 1 go to that same machine as well.
Note that there is no relation between the users and the products here: we don't care which products are in which partition or are sent to which machine.
But since we want to perform some aggregations on both the users and the products (separately), we need to keep all the records of an entity (user or product) together. And using this method, we can achieve that without any shuffles.
Also, note that if there are some users or products whose records don't fit in the workers' memory, then you need to do a sub-partitioning. That is, first add a new random numeric column to your data (between 0 and some `chunk_size`, like 10000 or so), then do the partitioning based on the combination of that number and the original IDs (like `userId`). This causes each entity to be split into fixed-size chunks (e.g., 10000) to ensure it fits in the workers' memory.
And after the aggregations, you need to group your data on the original IDs to aggregate all the chunks together and make each entity whole again.
The shuffle at the end is inevitable because of our memory restriction and the nature of our data, but this is the most efficient way to achieve the desired results.
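To illustrate that re-aggregation step, here is a rough sketch of aggregating per chunk and then merging the chunks per user. It assumes the random chunk column (here called `CHUNK`) was already added as described above, and that a hypothetical `AMOUNT` column is being summed:

```python
from pyspark.sql import functions as F

# First pass: aggregate within each (USER_ID, CHUNK) group, which is small
# enough to fit in a worker's memory.
partial = df.groupBy('USER_ID', 'CHUNK').agg(F.sum('AMOUNT').alias('PARTIAL_SUM'))

# Second pass: merge the chunks of each user back together; this group-by on
# the original ID is the one unavoidable shuffle mentioned above.
totals = partial.groupBy('USER_ID').agg(F.sum('PARTIAL_SUM').alias('TOTAL'))
```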