I am interested in using Spark SQL (1.6) to perform "filtered equi-joins" of the form
A inner join B where A.group_id = B.group_id and pair_filter_udf(A[cols], B[cols])
Here group_id is coarse: a single value of group_id could be associated with, say, 10,000 records in both A and B.
If the equi-join were performed by itself, without pair_filter_udf, the coarseness of group_id would create computational issues. For example, a group_id with 10,000 records in both A and B would contribute 100 million entries to the join. With many thousands of such large groups, we would generate an enormous table and could very easily run out of memory.
Thus, it is essential that pair_filter_udf be pushed down into the join so that pairs are filtered as they are generated, rather than after all pairs have been generated. My question is whether Spark SQL does this.
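To make the blowup concrete, here is the arithmetic in plain Python. The 10,000-record group size is the figure from above; the 5,000-group count is an illustrative stand-in for "many thousands":

```python
# Cardinality of the unfiltered equi-join, using the figures from the
# text. num_large_groups is an assumed value, not a measurement.
records_per_group = 10000   # rows sharing one group_id, in A and in B
num_large_groups = 5000     # stand-in for "many thousands" of groups

pairs_per_group = records_per_group ** 2          # 100 million per group
total_pairs = num_large_groups * pairs_per_group  # 500 billion overall

print(pairs_per_group)  # 100000000
print(total_pairs)      # 500000000000
```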
I set up a simple filtered equi-join and asked Spark what its query plan was:
# run in PySpark shell
import pyspark.sql.functions as F

sq = sqlContext
n = 100
g = 10

# Table A: ids 0..99, grouped into blocks of 10 (grp = floor(id/g)*g)
a = sq.range(n)
a = a.withColumn('grp', F.floor(a['id'] / g) * g)
a = a.withColumnRenamed('id', 'id_a')

# Table B: built the same way
b = sq.range(n)
b = b.withColumn('grp', F.floor(b['id'] / g) * g)
b = b.withColumnRenamed('id', 'id_b')

# Filtered equi-join: equality on grp plus a pair filter on (id_a, id_b)
c = a.join(b, (a.grp == b.grp) & (F.abs(a['id_a'] - b['id_b']) < 2)).drop(b['grp'])
c = c.sort('id_a')
c = c[['grp', 'id_a', 'id_b']]
c.explain()
Result:
== Physical Plan ==
Sort [id_a#21L ASC], true, 0
+- ConvertToUnsafe
+- Exchange rangepartitioning(id_a#21L ASC,200), None
+- ConvertToSafe
+- Project [grp#20L,id_a#21L,id_b#24L]
+- Filter (abs((id_a#21L - id_b#24L)) < 2)
+- SortMergeJoin [grp#20L], [grp#23L]
:- Sort [grp#20L ASC], false, 0
: +- TungstenExchange hashpartitioning(grp#20L,200), None
: +- Project [id#19L AS id_a#21L,(FLOOR((cast(id#19L as double) / 10.0)) * 10) AS grp#20L]
: +- Scan ExistingRDD[id#19L]
+- Sort [grp#23L ASC], false, 0
+- TungstenExchange hashpartitioning(grp#23L,200), None
+- Project [id#22L AS id_b#24L,(FLOOR((cast(id#22L as double) / 10.0)) * 10) AS grp#23L]
+- Scan ExistingRDD[id#22L]
These are the key lines from the plan:
+- Filter (abs((id_a#21L - id_b#24L)) < 2)
+- SortMergeJoin [grp#20L], [grp#23L]
These lines give the impression that the filter will be applied in a separate stage after the join, which is not the desired behavior. But maybe it is being implicitly pushed down into the join, and the query plan just lacks that level of detail.
How can I tell what Spark is doing in this case?
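For what it's worth, the behavior I'm hoping for can be phrased in the iterator model: if the join emits pairs lazily and the Filter above it consumes them one at a time, the full per-group cross product is never materialized even though Filter sits above the join in the plan. A pure-Python sketch of that pipelining, using this question's toy data (this is only an analogy, not Spark's actual implementation):

```python
def equi_join(a, b):
    # Toy stand-in for the join: yields same-grp pairs one at a time.
    # (A real SortMergeJoin walks two sorted inputs; the point here is
    # only that pairs are produced lazily, never collected into a list.)
    for ra in a:
        for rb in b:
            if ra['grp'] == rb['grp']:
                yield ra, rb

def pair_filter(pairs):
    # The Filter operator above the join: consumes pairs as they are
    # generated, so the full per-group cross product is never stored.
    for ra, rb in pairs:
        if abs(ra['id'] - rb['id']) < 2:
            yield ra, rb

n, g = 100, 10
a = [{'id': i, 'grp': (i // g) * g} for i in range(n)]
b = [{'id': i, 'grp': (i // g) * g} for i in range(n)]

# Only the surviving pairs are ever kept in memory.
out = list(pair_filter(equi_join(a, b)))
```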
Update:
I'm running experiments with n=1e6 and g=1e5, which should be enough to crash my laptop if Spark is not doing pushdown. Since it is not crashing, I guess it is doing pushdown. But it would be interesting to know how it works and what parts of the Spark SQL source are responsible for this awesome optimization.
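For reference, the same cardinality arithmetic for this experiment, assuming the unfiltered join were fully materialized:

```python
n = 10**6  # rows in each of A and B
g = 10**5  # group width, so grp = floor(id/g)*g

num_groups = n // g                     # 10 groups
unfiltered_pairs = num_groups * g ** 2  # 10 * 1e10 = 1e11 pairs

# Even at a few bytes per pair, 1e11 pairs is on the order of a
# terabyte, far beyond laptop memory. Hence the crash-test logic:
# finishing without crashing suggests pairs are filtered on the fly.
print(unfiltered_pairs)  # 100000000000
```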