
I have an algebraic data type that I want to use as a field in a case class. The ADT looks like this:

sealed abstract class DayOfWeek(val id: String)

object DayOfWeek {
  final object Sunday    extends DayOfWeek("sunday")
  final object Monday    extends DayOfWeek("monday")
  final object Tuesday   extends DayOfWeek("tuesday")
  final object Wednesday extends DayOfWeek("wednesday")
  final object Thursday  extends DayOfWeek("thursday")
  final object Friday    extends DayOfWeek("friday")
  final object Saturday  extends DayOfWeek("saturday")

  val members: List[DayOfWeek] = List(Sunday, Monday, Tuesday, Wednesday, Thursday, Friday, Saturday)

  def apply(id: String): DayOfWeek = members
    .map(member => (member.id, member))
    .toMap
    .apply(id)
}
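As a side note on the ADT itself: `apply` as written rebuilds the lookup map on every call and throws for unknown ids. A small refinement (my own suggestion, unrelated to the Spark question; the name `fromId` is my invention) builds the map once and adds a total lookup returning an `Option`:

```scala
sealed abstract class DayOfWeek(val id: String)

object DayOfWeek {
  object Sunday    extends DayOfWeek("sunday")
  object Monday    extends DayOfWeek("monday")
  object Tuesday   extends DayOfWeek("tuesday")
  object Wednesday extends DayOfWeek("wednesday")
  object Thursday  extends DayOfWeek("thursday")
  object Friday    extends DayOfWeek("friday")
  object Saturday  extends DayOfWeek("saturday")

  val members: List[DayOfWeek] =
    List(Sunday, Monday, Tuesday, Wednesday, Thursday, Friday, Saturday)

  // Built once, instead of on every apply call.
  private val byId: Map[String, DayOfWeek] =
    members.map(m => m.id -> m).toMap

  def apply(id: String): DayOfWeek = byId(id)

  // Total variant for untrusted input.
  def fromId(id: String): Option[DayOfWeek] = byId.get(id)
}
```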

I have seen answers here which say there is no good way to do it, like this one. But I don't believe that is the case.

There seems to be a path using UserDefinedType and UDTRegistration.register. Those are marked private as of Spark 2.x, but I tried using them from within the org.apache.spark namespace, which gets around the access restriction. But when I call .toDS on a Seq[DayOfWeek], even after calling register, I still get value toDS is not a member of Seq[DayOfWeek]. So the registration is not being picked up.

package org.apache.spark

import org.apache.spark.sql.types.{DataType, StringType, UDTRegistration, UserDefinedType}
import org.apache.spark.unsafe.types.UTF8String
// plus an import for DayOfWeek from wherever it is defined

object DayOfWeekUDT {
  def register(): Unit = UDTRegistration.register(classOf[DayOfWeek].getName, classOf[DayOfWeekUDT].getName)
}

class DayOfWeekUDT extends UserDefinedType[DayOfWeek] {
  override def sqlType: DataType = StringType
  override def serialize(obj: DayOfWeek): Any = UTF8String.fromString(obj.id)
  override def deserialize(datum: Any): DayOfWeek = DayOfWeek(datum.toString)
  override def userClass: Class[DayOfWeek] = classOf[DayOfWeek]
}

There is also the option of creating an implicit val of type Encoder[DayOfWeek] using ExpressionEncoder. I searched all of GitHub for examples. The few I could find didn't apply to my specific need, and I couldn't understand them well enough to make my own version work. This SHOULD work with Spark 2.x (in my case 2.4.x); it's just a matter of figuring out how to use this tool. This is what I tried, inside the DayOfWeek companion object:

  private val clazz: Class[DayOfWeek] = classOf[DayOfWeek]

  private val inputObject: BoundReference = BoundReference(0, ObjectType(clazz), false)

  private val converter = StaticInvoke(
    classOf[UTF8String],
    StringType,
    "fromString",
    Invoke(inputObject, "id", ObjectType(classOf[String])) :: Nil
  )

  private val serializer: Seq[Expression] = CreateNamedStruct(Literal("day_of_week") :: converter :: Nil).flatten

  private val deserializer: Expression = StaticInvoke(
    staticObject = DayOfWeek.getClass,
    dataType = ObjectType(clazz),
    functionName = "apply",
    arguments = Invoke(
      targetObject = UpCast(
        child = GetColumnByOrdinal(0, StringType),
        dataType = StringType,
        walkedTypePath = "- root class: DayOfWeek" :: Nil
      ),
      functionName = "id",
      dataType = ObjectType(classOf[String])
    ) :: Nil,
    propagateNull = false,
    returnNullable = false
  )

  implicit val encoder: Encoder[DayOfWeek] = new ExpressionEncoder[DayOfWeek](
    schema = StructType(Seq(StructField("id", StringType, false))),
    flat = true,
    serializer = serializer,
    deserializer = deserializer,
    clsTag = ClassTag(classOf[DayOfWeek])
  )

Does anyone have any idea how to do this properly? I like the UserDefinedType concept, in that the conversion back and forth between the Spark data type and the ADT is very clear. The ExpressionEncoder almost looks like it was written to be as cryptic as possible.

DCameronMauch
  • The toDS does not work because your type is not a Product type (a case class) – Emiliano Martinez Oct 10 '20 at 09:46
  • Yes, it's not a Product type. But even if I extended Product, how would that help? I can encode it as a child class, and that works. But it can't be decoded: there is no information to decode to a child class, and the parent is an abstract class and thus can't be directly instantiated. – DCameronMauch Oct 10 '20 at 14:31

1 Answer


I know this question is pretty old, but still. As the question you referenced points out, there is no good solution to this problem, where by "good" I mean a generic, manageable, systematic way to put arbitrary objects into Spark Datasets. There is no such way. Spark could support sum-type ADTs, but I believe that would severely limit the Catalyst and Tungsten optimizations. Having said that, in this particular case you can invent an ad hoc way to get your object into a Dataset.

Actually, your DayOfWeekUDT works fine. You can use it like this:

import scala.collection.JavaConverters._  // for .asJava (use scala.jdk.CollectionConverters._ on Scala 2.13)

import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.types.{StructField, StructType}

DayOfWeekUDT.register()
val schema = StructType(List(StructField("day_of_week", new DayOfWeekUDT())))
val values = List(DayOfWeek.Friday, DayOfWeek.Monday)
val df = spark.createDataFrame(values.map(RowFactory.create(_)).asJava, schema)
df.show(false)
df.printSchema()

As you can see, you can create a DataFrame from it. But you cannot turn this DataFrame into a Dataset without an Encoder, and you cannot do values.toDF() either, because that also requires an Encoder. So basically UserDefinedType is the old way of integrating custom types into Spark SQL, namely into DataFrames, while Encoder is how types are integrated into Datasets; the two are orthogonal to each other. Dataset is the more modern API, which evolved from the effort to combine the good parts of the RDD and DataFrame APIs, and eventually Dataset subsumed DataFrame, which is now just Dataset[Row]. I consider Encoder to be the primary way to integrate custom types, and UserDefinedType effectively deprecated, which is why it was made spark-package scoped.
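To see why toDS is not found without an Encoder, here is a minimal plain-Scala sketch (no Spark; `Encoder`, `DatasetOps`, and this `toDS` are stand-ins I made up for illustration) of the implicit-resolution mechanism involved: toDS only compiles when an implicit Encoder[T] can be found, and an implicit placed in T's companion object is in implicit scope everywhere, with no import needed:

```scala
object EncoderScopeDemo extends App {
  // Stand-in for Spark's Encoder type class.
  trait Encoder[T] { def name: String }

  // Stand-in for the extension Spark adds via `import spark.implicits._`:
  // toDS requires an implicit Encoder[T].
  implicit class DatasetOps[T](values: Seq[T]) {
    def toDS()(implicit enc: Encoder[T]): String =
      s"Dataset[${enc.name}] of ${values.size} rows"
  }

  sealed abstract class DayOfWeek(val id: String)
  object DayOfWeek {
    object Monday extends DayOfWeek("monday")
    object Friday extends DayOfWeek("friday")
    // Lives in the companion object, so it is found by implicit search
    // for Encoder[DayOfWeek] without any import at the call site.
    implicit val encoder: Encoder[DayOfWeek] =
      new Encoder[DayOfWeek] { val name = "DayOfWeek" }
  }

  println(List[DayOfWeek](DayOfWeek.Monday, DayOfWeek.Friday).toDS())
  // prints: Dataset[DayOfWeek] of 2 rows
}
```

This is also why putting the real implicit Encoder[DayOfWeek] in the DayOfWeek companion object is convenient: callers get it without importing anything.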

As for your Encoder implementation, I amended it this way:

  // For reference (Spark 3), the imports this needs at the top of the file:
  // import scala.reflect.ClassTag
  // import org.apache.spark.sql.Encoder
  // import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
  // import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  // import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, Literal}
  // import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
  // import org.apache.spark.sql.types.{ObjectType, StringType}
  // import org.apache.spark.unsafe.types.UTF8String

  private val clazz: Class[DayOfWeek] = classOf[DayOfWeek]

  private val inputObject: BoundReference = BoundReference(0, ObjectType(clazz), false)

  private val converter = StaticInvoke(
    classOf[UTF8String],
    StringType,
    "fromString",
    Invoke(inputObject, "id", ObjectType(classOf[String])) :: Nil
  )

  private val serializer: Expression = CreateNamedStruct(Literal("day_of_week") :: converter :: Nil)

  private val deserializer: Expression = StaticInvoke(
    staticObject = DayOfWeek.getClass,
    dataType = ObjectType(clazz),
    functionName = "apply",
    arguments = Invoke(GetColumnByOrdinal(0, StringType), "toString", ObjectType(classOf[String])) :: Nil,
    propagateNull = false,
    returnNullable = false
  )

  implicit val encoder: Encoder[DayOfWeek] = new ExpressionEncoder[DayOfWeek](
    serializer,
    deserializer,
    clsTag = ClassTag(classOf[DayOfWeek])
  )

This compiles against Spark 3 and passes a quick sanity test:

import spark.implicits._  // brings toDS into scope; it picks up our implicit Encoder

val values = List(DayOfWeek.Friday, DayOfWeek.Monday)
val ds = values.toDS()
ds.show()
ds.printSchema()
Rorick