I have two dataframes that I need to join together using a non-equi-join (i.e. an inequality join) that has two join predicates.
One dataframe is a histogram: DataFrame[bin: bigint, lower_bound: double, upper_bound: double].
The other dataframe is a collection of observations: DataFrame[id: bigint, observation: double].
I need to determine which bin of my histogram each observation falls into, like so:
observations_df.join(
    histogram_df,
    (observations_df.observation >= histogram_df.lower_bound) &
    (observations_df.observation < histogram_df.upper_bound)
)
This join is very slow, and I'm looking for suggestions on how to make it faster.
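For what it's worth, inspecting the physical plan (sketch below; joined is just the join above bound to a name) suggests why: with no equality predicate, Spark can't use a hash or sort-merge join, and on my setup it appears to fall back to a broadcast nested loop join, i.e. every observation is compared against every bin.
joined = observations_df.join(
    histogram_df,
    (observations_df.observation >= histogram_df.lower_bound) &
    (observations_df.observation < histogram_df.upper_bound)
)
# No equi-join keys, so the plan shows BroadcastNestedLoopJoin
# (or a cartesian product) rather than a hash/sort-merge join.
joined.explain()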
Below is some sample code that demonstrates the problem. observations_df contains 100,000 rows; once the number of rows in histogram_df becomes suitably large (say number_of_bins = 500000), the join becomes very, very slow, and I'm certain it's because I'm doing a non-equi-join. If you run this code, play around with the value of number_of_bins: start with something low, then increase it until the slow performance is noticeable.
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import lit, col, lead, rand

spark = SparkSession.builder.getOrCreate()
number_of_bins = 500000
bin_width = 1.0 / number_of_bins

# Each bin's upper_bound is the next bin's lower_bound (1.0 for the last bin).
window = Window.orderBy('bin')
histogram_df = spark.range(0, number_of_bins) \
    .withColumnRenamed('id', 'bin') \
    .withColumn('lower_bound', lit(bin_width) * col('bin')) \
    .select('bin', 'lower_bound',
            lead('lower_bound', 1, 1.0).over(window).alias('upper_bound'))

observations_df = spark.range(0, 100000).withColumn('observation', rand())
observations_df.join(
    histogram_df,
    (observations_df.observation >= histogram_df.lower_bound) &
    (observations_df.observation < histogram_df.upper_bound)
).groupBy('bin').count().head(15)
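As an aside, the bins in this repro are uniform width, so here the bin index could be computed directly instead of joined for (minimal sketch below, using floor; it obviously only works while the bins stay uniform). I mention it for comparison, but I'm still after an approach that works for a general non-equi-join.
from pyspark.sql.functions import floor

# Uniform bins only: the bin index is just floor(observation / bin_width),
# which replaces the non-equi-join with a single projection.
binned_df = observations_df.withColumn(
    'bin', floor(col('observation') / lit(bin_width)))
binned_df.groupBy('bin').count().head(15)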