
The context is to register a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether this is possible.

Say the method impl() returns an anonymous function:

trait Base {}
class A extends Base {
  def impl(): Function1[Int, String] = new Function1[Int, String] {
    def apply(x: Int): String = "ab" + x.toString
  }
}
// `reflections` is an org.reflections.Reflections instance
val classes = reflections.getSubTypesOf(classOf[Base]).toSet[Class[_ <: Base]].toList

and I obtain the anonymous function in another place:

val clazz = classes(0)
val instance = clazz.newInstance()
val impl = clazz.getDeclaredMethod("impl").invoke(instance)

Now impl holds the anonymous function, but I do not know its signature. Can we convert it into a correctly typed function instance?

impl.asInstanceOf[Function1[Int, String]]   // How to determine the function signature of the anonymous function, in this case Function1[Int, String]?

Since Scala does not support generic functions, I first considered getting the runtime type of the function:

import scala.reflect.runtime.universe.{TypeTag, typeTag}
def getTypeTag[T: TypeTag](obj: T) = typeTag[T]
val typeList = getTypeTag(impl).tpe.typeArgs

It returns List(Int, String) when the function's static type is known, but via reflection I fail to recognize the correct function template.
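
A minimal sketch of the difference, reusing the getTypeTag helper above:

val f: Int => String = x => "ab" + x.toString
getTypeTag(f).tpe.typeArgs // List(Int, String): the static type is known here

val g: AnyRef = f          // the static type that Method#invoke gives back
getTypeTag(g).tpe.typeArgs // List(): the element types are already lost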

Update: if the classes are defined as follows:

trait Base {}
class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}

where impl is the method itself and we do not know its signature, can impl still be registered?

  • Can you clarify what you have as input? And what you'd like to be able to do. – Gaël J Mar 09 '23 at 12:03
  • It's not clear what you're asking. *"scala does not support generic function"* What does this mean? *"I fail to recognize the correct function template"* What does this mean? Why from `List(Int, String)` can't you conclude that this is a `Function1[Int, String]`? – Dmytro Mitin Mar 09 '23 at 12:18
  • `getTypeTag(impl())` is `TypeTag[Int => String]` aka `TypeTag[Function1[Int, String]]`. – Dmytro Mitin Mar 09 '23 at 12:25
  • It sounds as if OP received `Function1[_, _]` where the input and output types were lost due to type erasure. However, there is no context, so it's hard to tell why it was lost, or how OP can be sure that what they receive is `Function1[Int, String]` or anything else. – Mateusz Kubuszok Mar 09 '23 at 20:45
  • The context is to register a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether it is possible to do so. – Hang Wu Mar 10 '23 at 02:12
  • https://stackoverflow.com/questions/56991021/how-to-get-the-static-type-of-an-expression-in-scala https://stackoverflow.com/questions/62344603/scala-defining-custom-type-type-mismatch-error https://stackoverflow.com/questions/72151052/the-opposite-of-compiletime-constvaluet-in-scala-3 https://stackoverflow.com/questions/64220935/type-annotations-overriden-by-inferred-expression-type – Dmytro Mitin Mar 10 '23 at 05:22
  • If you use runtime reflection, all bets are off. Either you'd extract the input and output from somewhere else if possible - because the bytecode for `Function1` might not have it - or if you have control over how this function is defined, pass some information next to the value. Hard to say without looking at any code that defines and extracts this value. – Mateusz Kubuszok Mar 10 '23 at 13:12
  • https://stackoverflow.com/questions/75224941/scala-cast-object-based-on-reflection-symbol https://stackoverflow.com/questions/74488026/public-method-must-have-explicit-type https://stackoverflow.com/questions/7858588/how-do-i-view-the-type-of-a-scala-expression-in-intellij https://stackoverflow.com/questions/19386964/i-want-to-get-the-type-of-a-variable-at-runtime Also type classes `shapeless.ops.function.{FnFromProduct, FnToProduct}`/`scala.util.TupledFunction`. – Dmytro Mitin Mar 10 '23 at 13:48

1 Answer


The context is to register a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether it is possible to do so.

Normally you register a UDF as follows

import org.apache.spark.sql.SparkSession

object App {
  val spark = SparkSession.builder
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  def impl(): Int => String = x => "ab" + x.toString

  spark.udf.register("foo", impl())

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
    //+-------+
    //|foo(10)|
    //+-------+
    //|   ab10|
    //+-------+
  }
}

The signature of register is

def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction

aka

def register[RT, A1](name: String, func: Function1[A1, RT])(implicit
  ttag:  TypeTag[RT],
  ttag1: TypeTag[A1]
): UserDefinedFunction

What a TypeTag normally does is persist type information from compile time to runtime.

So, in order to call register, you either have to know the types at compile time or know how to construct the type tags at runtime.
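
For illustration, a minimal sketch of the compile-time case, reusing impl() from the snippet above (describe is a hypothetical helper, not part of the Spark API):

import scala.reflect.runtime.universe._

// the compiler materializes TypeTag[Int] and TypeTag[String] at the call site,
// so both types remain inspectable at runtime
def describe[A: TypeTag, B: TypeTag](f: A => B): String =
  s"${typeOf[A]} => ${typeOf[B]}"

describe(impl()) // "Int => String"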

If you don't have access to how impl() is constructed at runtime, and you don't have (at least at runtime) any information about the types/type tags, then unfortunately this type information is irreversibly lost because of type erasure (Function1[Int, String] is just Function1[_, _] at runtime)

def impl(): Any = (x: Int) => "ab" + x.toString
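
For example, a minimal sketch of what this loss looks like at runtime:

val f = impl()                          // static type: Any
f.isInstanceOf[Function1[_, _]]         // true: this much survives erasure
f.isInstanceOf[Function1[String, Unit]] // also true (unchecked): the parameters are erased
// a wrong cast only fails when the erased apply is actually invoked:
// f.asInstanceOf[String => String]("x") // ClassCastException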

But it's possible that you do have access to how impl() is constructed at runtime and you do know (at least at runtime) the information about the types/type tags. So I assume that you don't have the types Int, String statically and can't call typeTag[Int], typeTag[String] (as I do below), but you somehow have runtime objects of Type/TypeTag

import org.apache.spark.sql.catalyst.ScalaReflection.universe._

def impl(): Any = (x: Int) => "ab" + x.toString
val ttag1 = typeTag[Int]    // actual definition is probably different
val ttag  = typeTag[String] // actual definition is probably different

In that case you can call register, resolving the implicits explicitly

spark.udf.register("foo", impl().asInstanceOf[Function1[_,_]])(ttag.asInstanceOf[TypeTag[_]], ttag1.asInstanceOf[TypeTag[_]])

Well, this doesn't compile because of existential types, but you can trick the compiler (note that the abstract type members below have to be declared inside a class/trait/object rather than inside a block, otherwise the compiler reports "Block cannot contain declarations")

type A
type B
spark.udf.register("foo", impl().asInstanceOf[A => B])(ttag.asInstanceOf[TypeTag[B]], ttag1.asInstanceOf[TypeTag[A]])

https://gist.github.com/DmytroMitin/0b3660d646f74fb109665bad41b3ae9f

Alternatively you can use runtime compilation (creating a new compile time inside the runtime)

import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import scala.tools.reflect.ToolBox // libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value

val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl().asInstanceOf[$ttag1 => $ttag])""")

https://gist.github.com/DmytroMitin/5b5dd4d7db0d0eebb51dd8c16735e0fb

You should provide some code showing how you construct impl(), and we'll see whether it's possible to restore the types.

Spark registered a Scala object all of the methods as a UDF

scala cast object based on reflection symbol


Update. After you get val impl = clazz.getDeclaredMethod("impl").invoke(instance), it's too late to restore the function's types (you can check that typeList is empty). The place where the function type (or type tag) should be captured is somewhere not too far from class A, maybe inside A or outside A, but where Int and String are not lost yet. What a TypeTag can do is persist type information from compile time to runtime; it can't restore type information at runtime once it's lost.

import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._
import scala.reflect.api

object App {
  def getType[T: TypeTag](obj: T) = typeOf[T]

  trait Base
  class A extends Base {
    def impl(): Int => String = x => "ab" + x.toString

    // this variant causes NotSerializableException:
    //def impl(): Function1[Int, String] = new Function1[Int, String] {
    //  def apply(x: Int): String = "ab" + x.toString
    //}

    val tpe = getType(impl())
  }

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

  val clazz = classes(0)
  val instance = clazz.newInstance()
  val impl = clazz.getDeclaredMethod("impl").invoke(instance)
  val functionType = clazz.getDeclaredMethod("tpe").invoke(instance).asInstanceOf[Type]
  val List(argType, returnType) = functionType.typeArgs

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val rm = ScalaReflection.mirror

  // (*)
  def typeToTypeTag[T](tpe: Type): TypeTag[T] =
    TypeTag(rm, new api.TypeCreator {
      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
        tpe.asInstanceOf[U#Type]
    })

//  type X
//  type Y
//  spark.udf.register("foo", impl.asInstanceOf[X => Y])(
//    typeToTypeTag[Y](returnType),
//    typeToTypeTag[X](argType)
//  )

  // the pattern binds fresh type variables x, y standing for the erased types
  impl match {
    case impl: Function1[x, y] => spark.udf.register("foo", impl)(
      typeToTypeTag[y](returnType),
      typeToTypeTag[x](argType)
    )
  }

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }

}

https://gist.github.com/DmytroMitin/2ebfae922f8a467d01b6ef18c8b8e5ad

(*) Get a TypeTag from a Type?

Now spark.sql("""SELECT foo(10)""").show() throws java.io.NotSerializableException but I guess it's not related to reflection.

Alternatively you can use runtime compilation (instead of manual resolution of implicits and construction of type tags from types)

import scala.tools.reflect.ToolBox

val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl.asInstanceOf[$functionType])""")

https://gist.github.com/DmytroMitin/ba469faeca2230890845e1532b36e2a1

One more option is to request the return type of the method impl() as soon as we get class A (outside A)

class A extends Base {
  def impl(): Int => String = x => "ab" + x.toString
}

// ...
val functionType = rm.classSymbol(clazz).typeSignature.decl(TermName("impl")).asMethod.returnType
val List(argType, returnType) = functionType.typeArgs

https://gist.github.com/DmytroMitin/3bd2c19d158f8241a80952c397ee5e09


Update 2. If the methods are defined as follows:

class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}

then runtime compilation normally should be

val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val tb = rm.mkToolBox()

tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).$methodSymbol(_))""")

or

tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).impl(_))""")

but now with Spark it produces ClassCastException: cannot assign instance of java.lang.invoke.SerializedLambda to field org.apache.spark.sql.catalyst.expressions.ScalaUDF.f of type scala.Function1 in instance of org.apache.spark.sql.catalyst.expressions.ScalaUDF similarly to Spark registered a Scala object all of the methods as a UDF

https://gist.github.com/DmytroMitin/b0f110f4cf15e2dfd4add70f7124a7b6

But ordinary Scala runtime reflection seems to work

val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val returnType = methodSymbol.returnType
val argType = methodSymbol.paramLists.head.head.typeSignature

val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
val instance = rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)()
val impl: Any => Any = rm.reflect(instance).reflectMethod(methodSymbol)(_)

def typeToTypeTag[T](tpe: Type): TypeTag[T] =
  TypeTag(rm, new api.TypeCreator {
    def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
      tpe.asInstanceOf[U#Type]
  })

impl match {
  case impl: Function1[x, y] => spark.udf.register("foo", impl)(
    typeToTypeTag[y](returnType),
    typeToTypeTag[x](argType)
  )
}

https://gist.github.com/DmytroMitin/763751096fe9cdb2e0d18ae4b9290a54


Update 3. One more approach is to use compile-time reflection (macros) rather than runtime reflection if you have enough information at compile time (e.g. if all the classes are known at compile time)

import scala.collection.mutable
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod[A](): Unit = macro registerMethodImpl[A]

  def registerMethodImpl[A: c.WeakTypeTag](c: blackbox.Context)(): c.Tree = {
    import c.universe._
    val A = weakTypeOf[A]

    var children = mutable.Seq[Type]()

    val traverser = new Traverser {
      override def traverse(tree: Tree): Unit = {
        tree match {
          case _: ClassDef =>
            val tpe = tree.symbol.asClass.toType
            if (tpe <:< A && !(tpe =:= A)) children :+= tpe
          case _ =>
        }

        super.traverse(tree)
      }
    }

    c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))

    val calls = children.map(tpe =>
      q"""spark.udf.register("foo", (new $tpe).impl(_))"""
    )

    q"..$calls"
  }
}
// in a different subproject

import org.apache.spark.sql.SparkSession

object App {
  trait Base

  class A extends Base {
    def impl(x: Int): String = "ab" + x.toString
  }

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  Macros.registerMethod[Base]()

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}

https://gist.github.com/DmytroMitin/6623f1f900330f8341f209e1347a0007
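
For completeness, a minimal sketch of the sbt wiring this assumes (project layout and versions are hypothetical): Scala 2 def macros must be compiled in a separate compilation unit before the code that expands them, hence the "in a different subproject" comment above.

// build.sbt -- a sketch; names and versions are assumptions
lazy val macros = (project in file("macros"))
  .settings(
    // blackbox def macros need scala-reflect on the compile classpath
    libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
  )

lazy val app = (project in file("app"))
  .dependsOn(macros) // compiled after macros, so the macro can expand here
  .settings(
    libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.3.2"
  )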

Shapeless - How to derive LabelledGeneric for Coproduct (KnownSubclasses)


Update 4. If we replace val clazz = classes.head with classes.foreach(clazz => ...), then the issues with NotSerializableException can be fixed with inlining

import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod(clazz: Class[_]): Unit = macro registerMethodImpl

  def registerMethodImpl(c: blackbox.Context)(clazz: c.Tree): c.Tree = {
    import c.universe._

    val ScalaReflection = q"_root_.org.apache.spark.sql.catalyst.ScalaReflection"
    val rm = q"$ScalaReflection.mirror"
    val ru = q"$ScalaReflection.universe"
    val classSymbol = q"$rm.classSymbol($clazz)"
    val methodSymbol = q"""$classSymbol.typeSignature.decl($ru.TermName("impl")).asMethod"""
    val returnType = q"$methodSymbol.returnType"
    val argType = q"$methodSymbol.paramLists.head.head.typeSignature"

    val constructorSymbol = q"$classSymbol.typeSignature.decl($ru.termNames.CONSTRUCTOR).asMethod"
    val instance = q"$rm.reflectClass($classSymbol).reflectConstructor($constructorSymbol).apply()"
    val impl1 = q"(x: Any) => $rm.reflect($instance).reflectMethod($methodSymbol).apply(x)"
    val api = q"_root_.scala.reflect.api"

    def typeToTypeTag(T: Tree, tpe: Tree): Tree =
      q"""
        $ru.TypeTag[$T]($rm, new $api.TypeCreator {
          override def apply[U <: $api.Universe with _root_.scala.Singleton](m: $api.Mirror[U]) =
            $tpe.asInstanceOf[U#Type]
        })
      """

    val impl2 = TermName(c.freshName("impl2"))
    val x = TypeName(c.freshName("x"))
    val y = TypeName(c.freshName("y"))
    q"""
      $impl1 match {
        case $impl2: _root_.scala.Function1[$x, $y] => spark.udf.register("foo", $impl2)(
          ${typeToTypeTag(tq"$y", returnType)},
          ${typeToTypeTag(tq"$x", argType)}
        )
      }
    """
  }
}
// in a different subproject

import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._

trait Base
class A extends Base /*with Serializable*/ {
  def impl(x: Int): String = "ab" + x.toString
}

object App {
  val spark: SparkSession = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

  classes.foreach(clazz =>
    Macros.registerMethod(clazz)
  )

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}

https://gist.github.com/DmytroMitin/c926158a9ff94a6539097c603bbedf6a

  • Thanks for your suggestion. The case is that impl() is defined in another class, and I just want to use reflection to achieve the goal to "register any anonymous function as a UDF", so we cannot just write "spark.udf.register("udf", (new Class1).impl())" directly. – Hang Wu Mar 11 '23 at 03:37
  • The example ``` type A type B spark.udf.register("foo", impl().asInstanceOf[A => B])(ttag.asInstanceOf[TypeTag[B]], ttag1.asInstanceOf[TypeTag[A]]) ``` does not compile either. The compiler complains "Block cannot contain declarations". – Hang Wu Mar 11 '23 at 03:50
  • @HangWu Please don't write code in comments. You should better update your question (press "edit"). – Dmytro Mitin Mar 11 '23 at 03:52
  • The last runtime compilation also doesn't work. The compiler complains "Can't unquote Object, consider providing an implicit instance of Liftable[Object]". – Hang Wu Mar 11 '23 at 03:55
  • @HangWu Both of them compile https://gist.github.com/DmytroMitin/0b3660d646f74fb109665bad41b3ae9f https://gist.github.com/DmytroMitin/5b5dd4d7db0d0eebb51dd8c16735e0fb Please edit your question with details how you register the UDF, how you implement `impl()` etc. – Dmytro Mitin Mar 11 '23 at 04:14
  • I've added more context. – Hang Wu Mar 11 '23 at 04:24
  • Awesome! The last remaining problem is how to deal with serialization. Just adding "@transient" before impl() does not work :( – Hang Wu Mar 11 '23 at 06:47
  • "def impl(): Function1[Int, String] = (x: Int) => "ab" + x.toString" It works if I rewrite impl as above :) – Hang Wu Mar 11 '23 at 06:55
  • If the classes are defined such that impl is the method directly: class A extends Base { def impl(x: Int): String = {"ab" + x.toString} }, I was wondering whether it's possible to register impl via reflection. – Hang Wu Mar 12 '23 at 03:35
  • If I replace the line "val clazz = classes.head" with "classes.foreach(clazz => {", I get a runtime error "Caused by: java.io.NotSerializableException: scala.reflect.runtime.JavaMirrors$JavaMirror". I tried but cannot fix it. Is this a bug? – Hang Wu Mar 13 '23 at 06:29
  • Even adding a pair of braces doesn't run either https://gist.github.com/abcwuhang/4327ecb941f7fbb40514f1f0426c302e. – Hang Wu Mar 13 '23 at 06:37
  • @HangWu Not sure. Something very Spark-specific. By the way, see one more update with the macro approach. – Dmytro Mitin Mar 13 '23 at 08:40
  • @HangWu Regarding your gist. If I move `val rm = ScalaReflection.mirror` outside the `{...}` block, then `java.io.NotSerializableException: scala.reflect.runtime.JavaMirrors$JavaMirror` changes to `java.io.NotSerializableException: A`. So I make `class A extends Base with Serializable`. Now the error is `java.io.NotSerializableException: scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anon$12` (`object not serializable (class: ...$$anon$12, value: method impl)`). Not necessarily a bug. Just Spark doesn't know how to serialize something from our reflection stuff. – Dmytro Mitin Mar 13 '23 at 18:53
  • @HangWu I finally managed to fix `NotSerializableException` but only with inlining https://gist.github.com/DmytroMitin/c926158a9ff94a6539097c603bbedf6a – Dmytro Mitin Mar 14 '23 at 01:32
  • https://medium.com/swlh/spark-serialization-errors-e0eebcf0f6e6 https://medium.com/onzo-tech/serialization-challenges-with-spark-and-scala-a2287cd51c54 https://medium.com/onzo-tech/serialization-challenges-with-spark-and-scala-part-2-now-for-something-really-challenging-bd0f391bd142 https://stackoverflow.com/questions/35337577/scala-reflection-with-serialization-over-spark-symbols-not-serializable – Dmytro Mitin Mar 14 '23 at 01:33
  • Cool. Seems that Scala could integrate reflection with serialization better. – Hang Wu Mar 14 '23 at 12:45
  • @HangWu Scala or Spark? ;) – Dmytro Mitin Mar 14 '23 at 13:49
  • @HangWu By the way, in Scala 3 you don't need [macros](https://docs.scala-lang.org/scala3/reference/metaprogramming/macros.html) for inlining, [inline methods](https://docs.scala-lang.org/scala3/reference/metaprogramming/inline.html) are enough. – Dmytro Mitin Mar 14 '23 at 13:55
  • I use Scala 2.12.8. – Hang Wu Mar 15 '23 at 02:59