I am still learning Scala and Spark, so I apologize if this is a basic question. I have looked through quite a few posts on StackOverflow and Google, and cannot find an answer to this yet. I have the following code -
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import java.sql.Timestamp
import java.time.LocalDate
import java.time.format.DateTimeFormatter
import scala.collection.{Iterable, Map}
val DATE_PATTERN = "yyyy-MM-dd"
val start_day = LocalDate.parse("2023-01-01", DateTimeFormatter.ofPattern(DATE_PATTERN))
val end_day = LocalDate.parse("2023-01-06", DateTimeFormatter.ofPattern(DATE_PATTERN))
val start_day_literal = lit(Timestamp.valueOf(start_day.atStartOfDay())).cast("timestamp")
val end_day_literal = lit(Timestamp.valueOf(end_day.atStartOfDay())).cast("timestamp")
val t_dates = spark.read
  .parquet("s3://my-bucket/calendar_dates/")
  .filter(col("calendar_day") >= start_day_literal)
  .filter(col("calendar_day") <= end_day_literal)
  .select("calendar_day")
t_dates.printSchema
t_dates.show(100, false)
// output
root
|-- calendar_day: timestamp (nullable = true)
+-------------------+
|calendar_day |
+-------------------+
|2023-01-01 00:00:00|
|2023-01-02 00:00:00|
|2023-01-03 00:00:00|
|2023-01-04 00:00:00|
|2023-01-05 00:00:00|
|2023-01-06 00:00:00|
+-------------------+
In the above, t_dates is just a table with a bunch of dates running from Jan 1, 1990 to Dec 31, 2050.
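In case it helps reproduce the question, an equivalent toy version of the filtered t_dates could be built without the Parquet table at all (this is just a stand-in I am using to show the shape; the real data comes from s3://my-bucket/calendar_dates/):

// one timestamp row per day in the range, same schema as the real calendar table
val t_dates_toy = spark.sql(
  "SELECT explode(sequence(to_timestamp('2023-01-01'), to_timestamp('2023-01-06'), interval 1 day)) AS calendar_day"
)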
Now the code continues like this -
val MY_TABLE = "reporting.customers"
val my_table_alias = "mt"
val d_dates_alias = "dt"
val customer_id = "customer_id"
val customer_type_id = "customer_type_id"
val customer_role_id = "customer_role_id"
val calendar_day = "calendar_day"
val start_date = "start_date" // in 'yyyy-MM-dd HH:mm:ss' format
val start_date_truncated = "start_date_truncated" // truncated start_date to just 'yyyy-MM-dd' format
val end_date = "end_date"
val end_date_nvl = "end_date_nvl" // if customer is still active, this end_date_nvl will be set to Dec 31, 2050; otherwise, it's equal to end_date, which is when customer stops his/her membership
val is_owner = "is_owner"
val is_guest = "is_guest"
val is_active = "is_active"
val beneficiary_external_id = "beneficiary_external_id" // assuming the source column has this name; it gets cast to customer_id below
def getAliasedColumnName(alias: String, columnName: String): String = {
alias.concat(".").concat(columnName)
}
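(For example, getAliasedColumnName(my_table_alias, customer_id) returns the string "mt.customer_id", which I then wrap in col(...).)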
// This is the meat of the code that I have questions about
val output_df = spark
  .table(MY_TABLE)
  .withColumn(customer_id, col(beneficiary_external_id).cast("long"))
  .withColumn(customer_type_id, upper(col(customer_type_id)))
  .withColumn(customer_role_id, upper(col(customer_role_id)))
  .filter(col(customer_type_id).isin("C"))
  .filter(col(customer_role_id).isin("O", "G"))
  .withColumn(is_owner, when(col(customer_role_id) === lit("O"), lit("Y")).otherwise(lit("N")))
  .withColumn(is_guest, when(col(customer_role_id) === lit("G"), lit("Y")).otherwise(lit("N")))
  .withColumn(start_date_truncated, col(start_date).cast("date"))
  .withColumn(end_date_nvl, expr("nvl(cast(end_date as date), to_timestamp('2050-12-31', 'yyyy-MM-dd'))"))
  .as(my_table_alias)
  .join(
    broadcast(t_dates).as(d_dates_alias),
    joinExprs = col(getAliasedColumnName(d_dates_alias, calendar_day)).between(
      lowerBound = col(getAliasedColumnName(my_table_alias, start_date_truncated)),
      upperBound = col(getAliasedColumnName(my_table_alias, end_date_nvl))
    ),
    joinType = "inner"
  )
  .select(
    col(getAliasedColumnName(d_dates_alias, calendar_day)),
    col(getAliasedColumnName(my_table_alias, customer_id)),
    col(getAliasedColumnName(my_table_alias, start_date_truncated)),
    col(getAliasedColumnName(my_table_alias, end_date_nvl)),
    col(getAliasedColumnName(my_table_alias, is_owner)),
    col(getAliasedColumnName(my_table_alias, is_guest)),
    col(getAliasedColumnName(my_table_alias, is_active))
  )
output_df.show(100, false);
output_df.count();
// output
+-------------------+------------+----------------+-------------------+--------+--------+---------+
|calendar_day |customer_id |start_date_trunc|end_date_nvl |is_owner|is_guest|is_active|
+-------------------+------------+----------------+-------------------+--------+--------+---------+
|2023-01-01 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-02 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-03 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-04 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-05 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-06 00:00:00|150634248002|2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
|2023-01-01 00:00:00|1067715662 |2019-12-25 |2050-12-31 00:00:00|Y |N |Y |
... the data continues. There are a total of 800 million rows in this table.
My questions are -
Q1: Suppose t_dates.calendar_day covers only 1-3 days at most (e.g., Jan 1 to Jan 3, 2023). How do I turn the INNER JOIN in the code above into something like .filter( (col("dt.calendar_day") >= col("mt.start_date_truncated")) && (col("dt.calendar_day") <= col("mt.end_date_nvl")) ) in Scala? Or something like below -
val output_df = spark
  .table(MY_TABLE)
  .withColumn(customer_id, col(beneficiary_external_id).cast("long"))
  .withColumn(customer_type_id, upper(col(customer_type_id)))
  .withColumn(customer_role_id, upper(col(customer_role_id)))
  .filter(col(customer_type_id).isin("C"))
  .filter(col(customer_role_id).isin("O", "G"))
  .withColumn(is_owner, when(col(customer_role_id) === lit("O"), lit("Y")).otherwise(lit("N")))
  .withColumn(is_guest, when(col(customer_role_id) === lit("G"), lit("Y")).otherwise(lit("N")))
  .withColumn(start_date_truncated, col(start_date).cast("date"))
  .withColumn(end_date_nvl, expr("nvl(cast(end_date as date), to_timestamp('2050-12-31', 'yyyy-MM-dd'))"))
  // This is how I imagine the filter should work
  .filter(
    col("t_dates.calendar_day")
      .between(
        col(getAliasedColumnName(my_table_alias, start_date_truncated)),
        col(getAliasedColumnName(my_table_alias, end_date_nvl))
      )
  )
  .as(my_table_alias)
  .select(
    col("t_dates.calendar_day"),
    col(getAliasedColumnName(my_table_alias, customer_id)),
    col(getAliasedColumnName(my_table_alias, start_date_truncated)),
    col(getAliasedColumnName(my_table_alias, end_date_nvl)),
    col(getAliasedColumnName(my_table_alias, is_owner)),
    col(getAliasedColumnName(my_table_alias, is_guest)),
    col(getAliasedColumnName(my_table_alias, is_active))
  )
But the above code obviously won't compile. I have read StackOverflow posts like this and still don't know how to properly use filter for dates in between.
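The closest thing I can come up with myself is sketched below. I am not sure it is right, and it changes the semantics: it only keeps customers whose membership window overlaps the small date range, instead of producing one row per calendar day. The idea is to collect the tiny date range onto the driver and filter with literal bounds:

// Sketch only: pull the bounds of the tiny t_dates table onto the driver
val bounds = t_dates.agg(min(col(calendar_day)), max(col(calendar_day))).head()
val minDay = bounds.getTimestamp(0) // earliest calendar_day in the range
val maxDay = bounds.getTimestamp(1) // latest calendar_day in the range

val overlapping_df = spark
  .table(MY_TABLE)
  .withColumn(start_date_truncated, col(start_date).cast("date"))
  .withColumn(end_date_nvl, expr("nvl(cast(end_date as date), to_timestamp('2050-12-31', 'yyyy-MM-dd'))"))
  // keep rows whose [start_date_truncated, end_date_nvl] window overlaps [minDay, maxDay]
  .filter(col(start_date_truncated) <= lit(maxDay) && col(end_date_nvl) >= lit(minDay))

Is something like that the right direction, or is there a way to express the between-filter directly against the joined t_dates column?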
Q2: This is a deeper-level question. I wonder whether using filter as proposed above (assuming it is even possible), instead of the INNER JOIN approach, would let Spark optimize better, since it could trim a lot of data with a pushed-down filter behind the scenes. I ask because, right now, the query above occasionally fails with the famous out-of-memory error when processing this particular table.
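For what it's worth, the only way I know to check what Spark actually does with the broadcast and the date predicates is to look at the query plan, though I am not sure how to read the memory behavior from it:

output_df.explain(true) // prints the parsed, analyzed, optimized, and physical plans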
Thanks very much in advance for your suggestions/answers!