I am trying to port the following Python code (Spark SQL, distance to the nearest holiday)

    last_holiday = index.value[0]
    for next_holiday in index.value:
        if next_holiday >= date:
            break
        last_holiday = next_holiday
    if last_holiday > date:
        last_holiday = None
    if next_holiday < date:
        next_holiday = None

to Scala. I do not have much Scala experience yet, but `break` does not seem like the clean, idiomatic Scala way to do it. Could you show me how to port this to Scala cleanly? This is my current attempt:

    breakable {
      for (next_holiday <- indexAT.value) {
        val next = next_holiday.toLocalDate
        println("next ", next)
        println("last ", last_holiday)

        if (next.isAfter(current) || next.equals(current)) break
        // check do I actually get here?
        last_holiday = Option(next)
      } // TODO this is so not scala and ugly ...
      if (last_holiday.isDefined) {
        if (last_holiday.get.isAfter(current)) {
          last_holiday = None
        }
      }
      if (last_holiday.isDefined) {
        if (last_holiday.get.isBefore(current)) {
          // TODO use one more var because out of scope
          next = None
        }
      }
    }

Here is the same code with a bit more context: https://gist.github.com/geoHeil/ff513b97a2b3e16241fdd9c8b0f3bdfb. I am also not sure how wide the `breakable` scope should be, but I hope to get rid of it entirely in an idiomatic Scala port of the code.

Georg Heiler
  • What is `date` in `if next_holiday >= date:`? If I understand correctly, you want a function that takes a collection of holidays, checks that the current date is not a holiday, and finds the number of days between the current (today's) date and the next holiday, right? – airudah Nov 23 '16 at 16:32

2 Answers

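This isn't a direct port, but I think it is closer to idiomatic Scala. I would treat the sorted list of holidays as a list of sequential pairs and then find the pair that the input date lies between. Dates before the first holiday or after the last one fall outside every pair and are handled separately.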

Here is a full example:

scala> import java.sql.Date
import java.sql.Date

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> :pa
// Entering paste mode (ctrl-D to finish)
def parseDate(in: String): java.sql.Date =
{
    val formatter = new SimpleDateFormat("MM/dd/yyyy")
    val d = formatter.parse(in)
    new java.sql.Date(d.getTime());
}
// Exiting paste mode, now interpreting.
parseDate: (in: String)java.sql.Date

scala> val holidays = Seq("11/24/2016", "12/25/2016", "12/31/2016").map(parseDate)
holidays: Seq[java.sql.Date] = List(2016-11-24, 2016-12-25, 2016-12-31)

scala> val hP = sc.broadcast(holidays.zip(holidays.tail))
hP: org.apache.spark.broadcast.Broadcast[Seq[(java.sql.Date, java.sql.Date)]] = Broadcast(4)

scala> def geq(d1: Date, d2: Date) = d1.after(d2) || d1.equals(d2)
geq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> def leq(d1: Date, d2: Date) = d1.before(d2) || d1.equals(d2)
leq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> :pa
// Entering paste mode (ctrl-D to finish)
val findNearestHolliday = udf((inDate: Date) => {
    val hP_l = hP.value
    val dates = hP_l.collectFirst{case (d1, d2) if (geq(inDate, d1) && leq(inDate, d2)) => (Some(d1), Some(d2))}
    dates.getOrElse(if (leq(inDate, hP_l.head._1)) (None, Some(hP_l.head._1)) else (Some(hP_l.last._2), None))
})
// Exiting paste mode, now interpreting.
findNearestHolliday: org.apache.spark.sql.UserDefinedFunction = UserDefinedFunction(<function1>,StructType(StructField(_1,DateType,true), StructField(_2,DateType,true)),List(DateType))

scala> val df = Seq((1, parseDate("11/01/2016")), (2, parseDate("12/01/2016")), (3, parseDate("01/01/2017"))).toDF("id", "date")
df: org.apache.spark.sql.DataFrame = [id: int, date: date]

scala> val df2 = df.withColumn("nearestHollidays", findNearestHolliday($"date"))
df2: org.apache.spark.sql.DataFrame = [id: int, date: date, nearestHollidays: struct<_1:date,_2:date>]

scala> df2.show
+---+----------+--------------------+
| id|      date|    nearestHollidays|
+---+----------+--------------------+
|  1|2016-11-01|   [null,2016-11-24]|
|  2|2016-12-01|[2016-11-24,2016-...|
|  3|2017-01-01|   [2016-12-31,null]|
+---+----------+--------------------+

scala> df2.foreach{println}
[3,2017-01-01,[2016-12-31,null]]
[1,2016-11-01,[null,2016-11-24]]
[2,2016-12-01,[2016-11-24,2016-12-25]]
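For comparison, here is a minimal Spark-free sketch of a more direct port of the original loop. It is only an illustration under some assumptions: the holidays are a sorted `Seq[java.time.LocalDate]`, and the names `NearestHolidays` and `surrounding` are made up for this example. The idea is that `span` does what the loop with `break` did by hand: it splits a sorted sequence at the first element that is no longer before the date. Note that it returns the last holiday strictly before the date and the first holiday on or after it, so the case where the date itself is a holiday may differ slightly from the Python version.

```scala
import java.time.LocalDate

object NearestHolidays {
  // Hypothetical helper, not part of the answer above.
  // Assumes `holidays` is sorted in ascending order.
  // Returns (last holiday strictly before `date`, first holiday on or after `date`).
  def surrounding(
      holidays: Seq[LocalDate],
      date: LocalDate
  ): (Option[LocalDate], Option[LocalDate]) = {
    // span stops at the first holiday that is not before `date`,
    // which replaces the manual loop and its `break`.
    val (before, onOrAfter) = holidays.span(_.isBefore(date))
    (before.lastOption, onOrAfter.headOption)
  }
}
```

With the three holidays from the example above, `NearestHolidays.surrounding(holidays, LocalDate.of(2016, 12, 1))` gives `(Some(2016-11-24), Some(2016-12-25))`, and a date after the last holiday gives `None` as the second element. Inside a UDF you would still need to convert between `java.sql.Date` and `LocalDate` (for example with `toLocalDate`).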
evan.oman
  • May I ask why you use the expression `holidays.dropRight(1).zip(holidays.drop(1))`? – Georg Heiler Nov 23 '16 at 19:49
  • That is the way I generate the sequential pairs: `List(1,2,3,4)` would become `List((1,2), (2,3), (3,4))`. – evan.oman Nov 23 '16 at 19:54
  • I just replaced it with `holidays.zip(holidays.tail)` which accomplishes the same thing but is more efficient – evan.oman Nov 23 '16 at 20:00
  • @GeorgHeiler let me know if you have any other questions or if this solution does not solve your problem – evan.oman Nov 23 '16 at 20:27
  • Sure - just want to try this tomorrow on my data. Thanks a lot so far. May I ask why you chose not to use a findFirst(current > holiday) assuming holiday is from a sorted Seq? – Georg Heiler Nov 23 '16 at 20:37
  • I cannot find a `findFirst` method for [`Seq`](http://www.scala-lang.org/api/current/scala/collection/immutable/Seq.html). However, `collectFirst` is described as "Finds the first element of the traversable or iterator for which the given partial function is defined, and applies the partial function to it.", so it accomplishes precisely what we want. We could also use `find` and then map over the resulting `Option`, but `collectFirst` is defined for this exact purpose. – evan.oman Nov 23 '16 at 20:43
  • could this be optimized assuming both df's (holidays and dates) are sorted? – Georg Heiler Nov 24 '16 at 07:25
  • Let us [continue this discussion in chat](http://chat.stackoverflow.com/rooms/128910/discussion-between-georg-heiler-and-evan058). – Georg Heiler Nov 24 '16 at 07:32

I've made an attempt to implement this in Scala:

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit

scala> val sdf = new SimpleDateFormat("dd/MM/yyyy")
sdf: java.text.SimpleDateFormat = java.text.SimpleDateFormat@d936eac0

//Here I've just assumed that the 15th of every other month is a public holiday
scala> val publicHolidays = for(interval <- 4 to 12 by 2) yield sdf.parse(s"15/$interval/2016")
publicHolidays: scala.collection.immutable.IndexedSeq[java.util.Date] = Vector(Fri Apr 15 00:00:00 BST 2016, Wed Jun 15 00:00:00 BST 2016, Mon Aug 15 00:00:00 BST 2016, Sat Oct 15 00:00:00 BST 2016, Thu Dec 15 00:00:00 GMT 2016)

//Today's date
scala> val currentDate = sdf.parse("23/11/2016")
currentDate: java.util.Date = Wed Nov 23 00:00:00 GMT 2016

scala> def findDaysTillNextHoliday: Long = {
     | // Note: .head throws if there is no holiday after currentDate
     | val nextHoliday = publicHolidays.toList.filter(_.after(currentDate)).head
     | TimeUnit.DAYS.convert(nextHoliday.getTime - currentDate.getTime, TimeUnit.MILLISECONDS)
     | }
findDaysTillNextHoliday: Long

scala> findDaysTillNextHoliday
res0: Long = 22 // i.e. 22 days until the next holiday, which is the 15th of December 2016

And for days since last holiday:

scala> def findDaysSinceLastHoliday: Long = {
     | val lastHoliday = publicHolidays.toList.filter(_.before(currentDate)).last
     | TimeUnit.DAYS.convert(currentDate.getTime - lastHoliday.getTime, TimeUnit.MILLISECONDS)
     | }
findDaysSinceLastHoliday: Long

scala> findDaysSinceLastHoliday
res1: Long = 39 //i.e 39 days since the last holiday which was 15th of October 2016
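One caveat with the approach above: `.head` and `.last` throw on an empty result, and dividing millisecond differences can be off by one day when a daylight saving change lies between the two dates. A possible variant, sketched here under the assumption that the holidays are available as a sorted `Seq[java.time.LocalDate]`, returns an `Option` and counts calendar days with `ChronoUnit.DAYS.between`. The names `HolidayDistance`, `daysTillNextHoliday` and `daysSinceLastHoliday` are made up for this illustration.

```scala
import java.time.LocalDate
import java.time.temporal.ChronoUnit

object HolidayDistance {
  // Hypothetical helpers; `holidays` is assumed to be sorted in ascending order.

  // Days until the first holiday strictly after `today`, or None if there is none.
  def daysTillNextHoliday(holidays: Seq[LocalDate], today: LocalDate): Option[Long] =
    holidays
      .find(_.isAfter(today))
      .map(h => ChronoUnit.DAYS.between(today, h))

  // Days since the last holiday strictly before `today`, or None if there is none.
  def daysSinceLastHoliday(holidays: Seq[LocalDate], today: LocalDate): Option[Long] =
    holidays.reverseIterator
      .find(_.isBefore(today))
      .map(h => ChronoUnit.DAYS.between(h, today))
}
```

With the same holidays as above (the 15th of every other month from April to December 2016) and 23 November 2016 as today, this gives `Some(22)` and `Some(39)`, matching the results shown. `LocalDate` has no time zone, so the BST/GMT switch does not affect the count.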
airudah