I am using an inner join to generate record comparisons, for the purpose of deduplicating data.
I would like to salt these joins so that record comparisons are more equally distributed in the presence of skew.
What follows is a very simple motivating example - the real input data is much larger.
Suppose we have a table as follows (csv here):

| first_name | surname | city      |
|------------|---------|-----------|
| charles    | dickens | london    |
| charlie    | dickens | london    |
| virginia   | woolf   | london    |
| virginia   | wolf    | london    |
| mary       | shelley | london    |
| jane       | austen  | steventon |
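For reference, the example table can also be built in-memory rather than read from the csv (a sketch, equivalent to the `spark.read.csv` call below):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Recreate the example table above as a DataFrame.
df = spark.createDataFrame(
    [
        ("charles", "dickens", "london"),
        ("charlie", "dickens", "london"),
        ("virginia", "woolf", "london"),
        ("virginia", "wolf", "london"),
        ("mary", "shelley", "london"),
        ("jane", "austen", "steventon"),
    ],
    ["first_name", "surname", "city"],
)
```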
To generate record comparisons, I can write SQL like this:
df = spark.read.csv("input.csv", header=True)  # path is a placeholder
df.createOrReplaceTempView("df")
sql = """
select
    l.first_name as first_name_l,
    r.first_name as first_name_r,
    l.surname as surname_l,
    r.surname as surname_r,
    l.city as city_l,
    r.city as city_r
from df as l
inner join df as r
on l.city = r.city
"""
spark.sql(sql)
On a large dataset, Spark will choose a SortMergeJoin. The data will be hash partitioned on `city`.
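(You can confirm this by inspecting the physical plan; the exact output depends on the data and Spark version.)

```python
# Show the physical plan - on a large input this should contain a
# SortMergeJoin preceded by exchanges (hash partitioning) on city.
spark.sql(sql).explain()
```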
All 5 records with `city = london` will therefore end up on a single executor, on which the Cartesian product of the london records will be produced - 25 comparisons in total.
This creates a problem on real datasets, where the count of `city = london` may be 10,000 - generating 100,000,000 comparisons in a single task on a single executor.
My question is: how can I salt this join to split up the work more evenly? Note that all 25 (or 100m) record comparisons still need to be generated - we just want them to be split between different tasks.
Solutions I've attempted
I have a working solution that's very inelegant, as follows. I'm looking to improve on this.
Step 1: Create a random integer column, `random_int`. For simplicity, let's say this contains integers in the range 1-3.
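For example, something along these lines (a sketch; generating the salt with `rand()` is just one way of doing it):

```python
from pyspark.sql import functions as F

# Add a salt column: floor(rand() * 3) yields 0, 1 or 2, so +1 gives 1-3.
df = df.withColumn("random_int", (F.floor(F.rand() * 3) + 1).cast("int"))
df.createOrReplaceTempView("df")
```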
Step 2: Run the inner join three times, once per value of `random_int`, and union the results:
select {cols}
from df as l
inner join df as r
on l.city = r.city and l.random_int = 1
UNION ALL
select {cols}
from df as l
inner join df as r
on l.city = r.city and l.random_int = 2
UNION ALL
select {cols}
from df as l
inner join df as r
on l.city = r.city and l.random_int = 3
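The repeated SQL can at least be generated in a loop rather than written out by hand (a sketch assuming 3 salt values and a `cols` string holding the select list; this only tidies the code, the resulting plan is the same):

```python
# Select list shared by every branch of the union.
cols = """
    l.first_name as first_name_l, r.first_name as first_name_r,
    l.surname as surname_l, r.surname as surname_r,
    l.city as city_l, r.city as city_r
"""

# One salted branch per value of random_int, glued together with UNION ALL.
branches = [
    f"""
    select {cols}
    from df as l
    inner join df as r
    on l.city = r.city and l.random_int = {k}
    """
    for k in range(1, 4)
]

comparisons = spark.sql("\nUNION ALL\n".join(branches))
```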
This solution gives the right answer, and does run faster on large datasets in the presence of skew. But it adds a lot of complexity to the execution plan, and I can't help feeling there must be a better way.
The real-world context of this problem is the blocking step of my open source software, Splink, so any help provided will help improve that software. (PRs are of course welcome as well!)