2

Problem: I am receiving data for multiple tables/schemas in a single stream. After segregating the data by table, I open a parallel write stream for each table.

The function I used in foreachBatch is:

def writeToAurora(df, batch_id, tableName):
    """Overwrite the ``<tablename>_delta`` staging table with one micro-batch.

    Has the ``(df, batch_id)`` leading signature expected by ``foreachBatch``;
    ``tableName`` must be bound by the caller.

    Args:
        df: Micro-batch DataFrame to write.
        batch_id: Micro-batch identifier passed by ``foreachBatch``; not used.
        tableName: Logical table name. It is lower-cased and suffixed with
            ``_delta`` to form the JDBC target table.
    """
    stagingTable = f'{tableName.lower()}_delta'

    df = df.persist()
    try:
        # Connection settings come from the module-level DB_conf mapping.
        # With truncate=true, overwrite mode empties the existing table
        # instead of dropping and recreating it.
        (
            df.write
            .mode("overwrite")
            .format("jdbc")
            .option("truncate", "true")
            .option("driver", DB_conf['DRIVER'])
            .option("batchsize", 1000)
            .option("url", DB_conf['URL'])
            .option("dbtable", stagingTable)
            .option("user", DB_conf['USER_ID'])
            .option("password", DB_conf['PASSWORD'])
            .save()
        )
    finally:
        # Release the cached batch even when the JDBC write raises, so a
        # failed batch does not stay cached.
        df.unpersist()

The logic to open multiple writestreams is

# Single streaming source: all tables' records are read from one Kinesis stream.
data_df = spark.readStream.format("kinesis") \
    .option("streamName", stream_name) \
    .option("startingPosition", initial_position) \
    .load()

                 
#Distinguishing table wise df: one filtered streaming DataFrame per table name,
#selected on the TableName column.
distinctTables = ['Table1', 'Table2', 'Table3']
tablesDF = {table: data_df.filter(f"TableName = '{table}'") for table in distinctTables}


#Processing Each Table
for table, tableDF in tablesDF.items():
    # Parse the '|'-separated 'finalData' column with this table's schema, then
    # flatten the parsed struct into top-level columns.
    df = tableDF.withColumn('csvData', F.from_csv('finalData', schema=tableSchema[table], options={'sep': '|','quote': '"'}))\
        .select('csvData.*')


    # Start one streaming query per table and store its handle in a
    # dynamically named variable (<table>_query).
    # NOTE(review): the lambda does not capture the current value of `table`.
    # It looks the name up each time a batch is processed (late binding), so
    # every query sees whatever `table` holds at that moment, not the table it
    # was created for.
    vars()[table+'_query'] = df.writeStream\
                .trigger(processingTime='120 seconds') \
                .foreachBatch(lambda fdf, batch_id: writeToAurora(fdf, batch_id, table)) \
                .option("checkpointLocation", f"s3://{bucket}/temporary/checkpoint/{table}")\
                .start()
                
                
# Wait on each started query in turn, looking the handle up by its generated name.
# NOTE(review): this loop reuses the name `table`, so it also changes the value
# the lambdas above read while the queries are running.
for table in tablesDF.keys():
    eval(table+'_query').awaitTermination()

Issue: When running the above code, data for table1 is sometimes loaded into table2, and the order is different each time the code runs. The mapping between each dataframe and the table it should be loaded into is not maintained.

Need help on understanding why this is happening.

Shubham Jain
  • 5,327
  • 2
  • 15
  • 38

2 Answers

3

This is caused by late binding for your lambda function in the foreachBatch method.

Here's an example. This will try to write all tables to "t2", and fails (actually only writing the "t2" table, but writing the "t0" data):

from pyspark.sql.functions import *
from pyspark.sql import *

def writeToTable(df, epochId, table_name):
    """Overwrite the ``stream_test_<table_name>`` table with one micro-batch.

    ``epochId`` is accepted for the ``foreachBatch`` signature and not used.
    """
    target = f"custanwo.dsci.stream_test_{table_name}"
    overwrite_writer = df.write.mode("overwrite")
    overwrite_writer.saveAsTable(target)

# Build a small streaming test dataset: count rows per key (value % 10) and tag
# each key with a target label "t" + (key % 3), i.e. t0, t1 or t2.
data_df = spark.readStream.format("rate").load()
data_df = (data_df
           .selectExpr("value % 10 as key")
           .groupBy("key")
           .count()
           .withColumn("t", concat(lit("t"),(col("key")%3).astype("string")))
)

table_names = ["t0", "t1", "t2"]
# One filtered stream per target label.
table_df = {t: data_df.filter(f"t = '{t}'") for t in table_names}

# Start one query per label, storing each handle in a variable named <t>_query.
for t, df in table_df.items():
    vars()[f"{t}_query"] = (df
                            .writeStream
                            # NOTE: `t` inside the lambda is resolved when the
                            # batch runs, not when the lambda is created (late
                            # binding), so every query sees the loop's final `t`.
                            .foreachBatch(lambda df, epochId: writeToTable(df, epochId, t))
                            .outputMode("update")
                            .start()
                            )

To resolve this there are a few options. One is using partial:

from functools import partial

def writeToTable(df, epochId, table_name):
    """Overwrite the ``stream_test_<table_name>`` table with one micro-batch.

    ``epochId`` is accepted for the ``foreachBatch`` signature and not used.
    """
    target = f"custanwo.dsci.stream_test_{table_name}"
    overwrite_writer = df.write.mode("overwrite")
    overwrite_writer.saveAsTable(target)

# Streaming test dataset: rows counted per key (value % 10), each key tagged
# with a target label "t" + (key % 3).
data_df = (
    spark.readStream.format("rate").load()
    .selectExpr("value % 10 as key")
    .groupBy("key")
    .count()
    .withColumn("t", concat(lit("t"), (col("key") % 3).astype("string")))
)

table_names = ["t0", "t1", "t2"]

# One filtered stream per target label.
table_df = {t: data_df.filter(f"t = '{t}'") for t in table_names}

for t, df in table_df.items():
    # partial() binds the current value of t when the callback is created, so
    # each query keeps its own table name.
    vars()[f"{t}_query"] = (
        df.writeStream
        .foreachBatch(partial(writeToTable, table_name=t))
        .outputMode("update")
        .start()
    )

In your code, rewrite your writeStream to:

    # partial() fixes tableName to this iteration's table when the callback is
    # built, so the value is not looked up later when a batch runs.
    vars()[table+'_query'] = (
        df.writeStream
        .trigger(processingTime='120 seconds')
        .foreachBatch(partial(writeToAurora, tableName=table))
        .option("checkpointLocation", f"s3://{bucket}/temporary/checkpoint/{table}")
        .start()
    )
s_pike
  • 1,710
  • 1
  • 10
  • 22
-1
def writeToAurora(df, batch_id, tableName):
    """Tag one micro-batch with its table name and overwrite the staging table.

    Adds a literal ``TableName`` column holding ``tableName`` to every row,
    then writes the batch over JDBC to ``<tablename>_delta``.

    Args:
        df: Micro-batch DataFrame to write.
        batch_id: Micro-batch identifier passed by ``foreachBatch``; not used.
        tableName: Logical table name. It is used for the ``TableName`` column
            value and, lower-cased with a ``_delta`` suffix, as the JDBC
            target table.
    """
    stagingTable = f'{tableName.lower()}_delta'

    df = df.withColumn("TableName", F.lit(tableName))  # Add the TableName column to the DataFrame
    df = df.persist()
    try:
        # Connection settings come from the module-level DB_conf mapping.
        # With truncate=true, overwrite mode empties the existing table
        # instead of dropping and recreating it.
        (
            df.write
            .mode("overwrite")
            .format("jdbc")
            .option("truncate", "true")
            .option("driver", DB_conf['DRIVER'])
            .option("batchsize", 1000)
            .option("url", DB_conf['URL'])
            .option("dbtable", stagingTable)
            .option("user", DB_conf['USER_ID'])
            .option("password", DB_conf['PASSWORD'])
            .save()
        )
    finally:
        # Release the cached batch even when the JDBC write raises, so a
        # failed batch does not stay cached.
        df.unpersist()

With this change, the DataFrame df passed to the writeToAurora function will now have an additional column named "TableName" containing the name of the table to which the data belongs. The writeToAurora function will then use this information to write the data to the appropriate staging table in Aurora.

ARCHER
  • 92
  • 3