
I am still learning Scala and Spark, so I apologize if this is a basic question. I have looked through quite a few posts on StackOverflow and Google, and cannot find an answer to this yet. I have the following code -

import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import java.sql.Timestamp
import java.time.LocalDate
import java.time.format.DateTimeFormatter

import scala.collection.{Iterable, Map}

val DATE_PATTERN = "yyyy-MM-dd"
val start_day = LocalDate.parse("2023-01-01", DateTimeFormatter.ofPattern(DATE_PATTERN))
val end_day = LocalDate.parse("2023-01-06", DateTimeFormatter.ofPattern(DATE_PATTERN))
val start_day_literal = lit(Timestamp.valueOf(start_day.atStartOfDay())).cast("timestamp")
val end_day_literal = lit(Timestamp.valueOf(end_day.atStartOfDay())).cast("timestamp")

val t_dates = spark.read
  .parquet("s3://my-bucket/calendar_dates/")
  .filter(col("calendar_day") >= start_day_literal)
  .filter(col("calendar_day") <= end_day_literal)
  .select("calendar_day")

t_dates.printSchema
t_dates.show(100, false)

// output
root
 |-- calendar_day: timestamp (nullable = true)

+-------------------+
|calendar_day       |
+-------------------+
|2023-01-01 00:00:00|
|2023-01-02 00:00:00|
|2023-01-03 00:00:00|
|2023-01-04 00:00:00|
|2023-01-05 00:00:00|
|2023-01-06 00:00:00|
+-------------------+
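As an aside, if you don't have the parquet file, an equivalent calendar DataFrame can be generated on the fly with Spark's built-in sequence function (just a sketch that reproduces the filtered t_dates shown above, not what I actually run):

// Sketch: one row per day between the two bounds, in a single timestamp column.
// sequence(start, stop, step) over timestamps is available in Spark 2.4+.
val t_dates_generated = spark.sql(
  """
    |SELECT explode(sequence(
    |         to_timestamp('2023-01-01'),
    |         to_timestamp('2023-01-06'),
    |         interval 1 day)) AS calendar_day
    |""".stripMargin
)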

In the above, t_dates is just a calendar table with one row per day, from Jan 1, 1990 to Dec 31, 2050. Now the code continues like this -

val MY_TABLE = "reporting.customers"
val my_table_alias = "mt"
val d_dates_alias = "dt"

val customer_id = "customer_id"
val beneficiary_external_id = "beneficiary_external_id" // source column that customer_id is derived from
val customer_type_id = "customer_type_id"
val customer_role_id = "customer_role_id"

val calendar_day = "calendar_day"
val start_date = "start_date" // in 'yyyy-MM-dd HH:mm:ss' format
val start_date_truncated = "start_date_truncated" // truncated start_date to just 'yyyy-MM-dd' format
val end_date = "end_date"
val end_date_nvl = "end_date_nvl" // if customer is still active, this end_date_nvl will be set to Dec 31, 2050; otherwise, it's equal to end_date, which is when customer stops his/her membership

val is_owner = "is_owner"
val is_guest = "is_guest"
val is_active = "is_active"

def getAliasedColumnName(alias: String, columnName: String): String = {
 alias.concat(".").concat(columnName)
}
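This helper just builds an alias-qualified column name, for example:

getAliasedColumnName(my_table_alias, customer_id) // "mt.customer_id", which col(...) resolves against the alias set by .as(my_table_alias) below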

// This is the meat of the code that I have questions about
val output_df = spark
.table(MY_TABLE)
.withColumn(customer_id, col(beneficiary_external_id).cast("long"))
.withColumn(customer_type_id, upper(col(customer_type_id)))
.withColumn(customer_role_id, upper(col(customer_role_id)))
.filter(col(customer_type_id).isin("C"))
.filter(col(customer_role_id).isin("O", "G"))
.withColumn(is_owner, when(col(customer_role_id) === lit("O"), lit("Y")).otherwise(lit("N")))
.withColumn(is_guest, when(col(customer_role_id) === lit("G"), lit("Y")).otherwise(lit("N")))
.withColumn(start_date_truncated, col(start_date).cast("date"))
.withColumn(end_date_nvl, expr("nvl(cast(end_date as date), to_timestamp('2050-12-31', 'yyyy-MM-dd'))"))
.as(my_table_alias)
.join(
    broadcast(t_dates).as(d_dates_alias),
    joinExprs = col(getAliasedColumnName(d_dates_alias, calendar_day)).between(
        lowerBound = col(getAliasedColumnName(my_table_alias, start_date_truncated)),
        upperBound = col(getAliasedColumnName(my_table_alias, end_date_nvl))
    ),
    joinType = "inner"
)
.select(
    col(getAliasedColumnName(d_dates_alias, calendar_day)),
    col(getAliasedColumnName(my_table_alias, customer_id)),
    col(getAliasedColumnName(my_table_alias, start_date_truncated)),
    col(getAliasedColumnName(my_table_alias, end_date_nvl)),
    col(getAliasedColumnName(my_table_alias, is_owner)),
    col(getAliasedColumnName(my_table_alias, is_guest)),
    col(getAliasedColumnName(my_table_alias, is_active))
)

output_df.show(100, false)
output_df.count()

// output
+-------------------+------------+----------------+-------------------+--------+--------+---------+
|calendar_day       |customer_id |start_date_trunc|end_date_nvl       |is_owner|is_guest|is_active|
+-------------------+------------+----------------+-------------------+--------+--------+---------+
|2023-01-01 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-02 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-03 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-04 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-05 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-06 00:00:00|150634248002|2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |
|2023-01-01 00:00:00|1067715662  |2019-12-25      |2050-12-31 00:00:00|Y       |N       |Y        |

... the data continues. There are a total of 800 million rows in this table

My questions are -

Q1: Suppose t_dates.calendar_day spans at most 1-3 days (e.g., Jan 1 to Jan 3, 2023). How do I turn the INNER JOIN in the code above into something like .filter( (col("dt.calendar_day") >= col("mt.start_date_truncated")) && (col("dt.calendar_day") <= col("mt.end_date_nvl")) ) in Scala? Or something like below -

val output_df = spark
.table(MY_TABLE)
.withColumn(customer_id, col(beneficiary_external_id).cast("long"))
.withColumn(customer_type_id, upper(col(customer_type_id)))
.withColumn(customer_role_id, upper(col(customer_role_id)))
.filter(col(customer_type_id).isin("C"))
.filter(col(customer_role_id).isin("O", "G"))
.withColumn(is_owner, when(col(customer_role_id) === lit("O"), lit("Y")).otherwise(lit("N")))
.withColumn(is_guest, when(col(customer_role_id) === lit("G"), lit("Y")).otherwise(lit("N")))
.withColumn(start_date_truncated, col(start_date).cast("date"))
.withColumn(end_date_nvl, expr("nvl(cast(end_date as date), to_timestamp('2050-12-31', 'yyyy-MM-dd'))"))

// This is how I imagine the filter should work
.filter(
    col("t_dates.calendar_day")
        .between(
            col(getAliasedColumnName(my_table_alias, start_date_truncated)),
            col(getAliasedColumnName(my_table_alias, end_date_nvl))
        )
)

.as(my_table_alias)
.select(
    col("t_dates.calendar_day"),
    col(getAliasedColumnName(my_table_alias, customer_id)),
    col(getAliasedColumnName(my_table_alias, start_date_truncated)),
    col(getAliasedColumnName(my_table_alias, end_date_nvl)),
    col(getAliasedColumnName(my_table_alias, is_owner)),
    col(getAliasedColumnName(my_table_alias, is_guest)),
    col(getAliasedColumnName(my_table_alias, is_active))
)

But the above code obviously won't compile. I have read StackOverflow posts like this and still don't know how to properly use filter for a dates-in-between condition.
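To make Q1 concrete, here is a sketch of the shape I'm imagining: keep the broadcast of t_dates but express the range check as a crossJoin followed by filter. Here customers_df is just a placeholder name for spark.table(MY_TABLE) after all the withColumn/filter steps shown above (everything up to and including end_date_nvl); whether Catalyst treats this any differently from the inner join with a between condition is exactly what Q2 asks.

// Sketch only: cross join the small, broadcast calendar, then filter on the date range.
// customers_df is a placeholder for the prepared customers DataFrame from the code above.
val filtered_df = customers_df
  .as(my_table_alias)
  .crossJoin(broadcast(t_dates).as(d_dates_alias))
  .filter(
    col(getAliasedColumnName(d_dates_alias, calendar_day))
      .between(
        col(getAliasedColumnName(my_table_alias, start_date_truncated)),
        col(getAliasedColumnName(my_table_alias, end_date_nvl))
      )
  )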

Q2: This is a deeper question. If I use filter as proposed above (assuming it is even possible) instead of the INNER JOIN approach, would Spark optimize the query better because it can prune a lot of data with a filter under the hood? I ask because right now the query above occasionally fails with the famous out-of-memory error when processing this table in particular.
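In case it helps with Q2, this is how I've been trying to compare the two variants myself, by looking at their physical plans (explain with the "formatted" mode is a Spark 3.x feature; on older versions explain(true) works):

// Compare the physical plans of the join-based and filter-based variants.
output_df.explain("formatted")   // or output_df.explain(true) on Spark 2.x
filtered_df.explain("formatted")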

Thanks very much in advance for your suggestions/answers!

user1330974
    Take a look to this answer: https://stackoverflow.com/a/56850663/6802156 – Emiliano Martinez Mar 01 '23 at 16:59
  • @EmilianoMartinez Thank you. This is way more complicated than I anticipated. :) I was able to go with something like this `.filter(col("start_date_truncated") <= calendar_day).filter(col(end_date_nvl) >= calendar_day)` A bit more verbose than I'd like, but it worked. Thank you again for sharing this fancy/advance way of doing things! – user1330974 Mar 02 '23 at 01:39
