I'm trying to read a table from an RDS MySQL instance using PySpark. The table is huge, so I want to parallelize the read by using Spark's JDBC partitioning options. The table doesn't have a numeric column to use as the partition column; it does have a timestamp column (datetime type).
I found the lower and upper bounds by retrieving the min and max values of the timestamp column (a rough sketch of that query is included after the code below). However, I'm not sure whether there's a standard formula to determine the number of partitions dynamically. Here is what I'm doing currently (hardcoding the value of the numPartitions parameter):
# dbtable needs a subquery wrapped in parentheses with an alias,
# not a bare SELECT statement
select_sql = "(SELECT {} FROM {}) AS t".format(columns, table)

partition_info = {'partition_column': 'col1',
                  'lower_bound': '<result of min(col1)>',
                  'upper_bound': '<result of max(col1)>',
                  'num_partitions': '10'}

read_df = spark.read.format("jdbc") \
    .option("driver", driver) \
    .option("url", url) \
    .option("dbtable", select_sql) \
    .option("user", user) \
    .option("password", password) \
    .option("useSSL", False) \
    .option("partitionColumn", partition_info['partition_column']) \
    .option("lowerBound", partition_info['lower_bound']) \
    .option("upperBound", partition_info['upper_bound']) \
    .option("numPartitions", partition_info['num_partitions']) \
    .load()
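For reference, this is roughly how I retrieve the bounds (the column name col1, the alias bounds, and the field names min_ts/max_ts are just illustrative placeholders for my actual column and table):

# Fetch the min/max of the timestamp column through the same JDBC connection
bounds_sql = "(SELECT MIN(col1) AS min_ts, MAX(col1) AS max_ts FROM {}) AS bounds".format(table)
bounds_row = spark.read.format("jdbc") \
    .option("driver", driver) \
    .option("url", url) \
    .option("dbtable", bounds_sql) \
    .option("user", user) \
    .option("password", password) \
    .option("useSSL", False) \
    .load() \
    .first()

lower_bound = str(bounds_row['min_ts'])   # e.g. '2018-01-01 00:00:00'
upper_bound = str(bounds_row['max_ts'])

One naive idea I've considered for picking numPartitions dynamically (I'm not sure whether something like this is sensible, which is why I'm asking):

# Roughly one partition per day of data, capped at an arbitrary maximum
num_days = (bounds_row['max_ts'] - bounds_row['min_ts']).days
num_partitions = max(1, min(num_days, 100))   # the cap of 100 is a guess, not a rule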
Please suggest a solution or share an approach that has worked for you. Thanks!