I've noticed that Spark's collect function is extremely slow on large datasets, so I'm trying to work around that using parallelize.
My main method creates the SparkSession and passes it to the get_data function:
from pyspark.sql import SparkSession

def main():
    spark = SparkSession.builder.appName('app_name').getOrCreate()
    return get_data(spark)
Here is where I try to parallelize the result of my collect call:
def get_data(spark):
    df = all_data(spark)
    data = spark.sparkContext.parallelize(df.select('my_column').distinct().collect())
    return map(lambda row: row['my_column'], data)
This does not work and fails with the following error:
TypeError: 'RDD' object is not iterable
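From what I can tell, the error happens because map here is Python's builtin, which requires a plain Python iterable, and an RDD is not one. Here is a minimal sketch of what I think an RDD-native version would look like, assuming all_data returns a DataFrame containing my_column:

def get_data(spark):
    df = all_data(spark)
    # Use the RDD's own map() rather than Python's builtin map(),
    # so the per-row work runs on the executors instead of the driver
    values = df.select('my_column').distinct().rdd.map(lambda row: row['my_column'])
    # collect() still has to bring the distinct values back to the driver
    return values.collect()

I'm not sure this actually helps with performance, though, since it still ends in a collect.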
Does anyone have any ideas on how to parallelize or otherwise improve the performance of the get_data function?
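For what it's worth, I also considered streaming the rows back instead of collecting them all at once. A sketch, assuming the set of distinct values is too large to hold comfortably in driver memory:

def get_data(spark):
    df = all_data(spark)
    # toLocalIterator() fetches one partition at a time instead of the
    # whole result set, trading extra round-trips for driver memory
    return (row['my_column'] for row in
            df.select('my_column').distinct().toLocalIterator())

Is something like this the right direction, or is there a better way?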