
I'm just getting started with macros, for a use case in which I try to extract the lambda argument names from a function. To do so, I've defined this object (let's say in a module A):

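import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context

object MacroTest {

  def getLambdaArgNames[A, B](f: A => B): String = macro getLambdaArgNamesImpl[A, B]

  def getLambdaArgNamesImpl[A, B](c: Context)(f: c.Expr[A => B]): c.Expr[String] = {
    import c.universe._
    // Destructure the argument tree as a lambda literal: (params) => body
    val Function(args, body) = f.tree
    val names = args.map(_.name)
    val argNames = names.mkString(", ")
    // Splice the joined names back into the call site as a String literal
    val constant = Literal(Constant(argNames))
    c.Expr[String](q"$constant")
  }
}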

Now, in another module, I'm trying to write a unit test that checks the names of the arguments of a lambda:

class TestSomething extends AnyFreeSpec with Matchers {
  "test" in {
    val f = (e1: Expr[Int]) => e1 === 3
    val argNames = MacroTest.getLambdaArgNames(f)
    println(argNames)
    assert(argNames === "e1")
  }
}

But this code doesn't compile; the macro expansion fails with: scala.MatchError: f (of class scala.reflect.internal.Trees$Ident)

However, if I pass the lambda directly to the function, as in MacroTest.getLambdaArgNames((e1: Expr[Int]) => e1 === 3), it works, so I'm pretty lost as to why the first version doesn't compile.

Is there any way to fix this?

Pᴇʜ
alifirat

2 Answers


It's the same issue as in the question "Scala macro inspect tree for anonymous function", where you already commented. You really do need to pass the lambda itself to the macro; as you say,

if I pass directly the lambda to the function like MacroTest.getLambdaArgNames((e1: Expr[Int]) => e1 === 3) it's working

When you write MacroTest.getLambdaArgNames(f), the macro's argument AST (f: c.Expr[A => B] in getLambdaArgNamesImpl) is just the identifier f (an Ident tree). The lambda's AST is not part of that tree, so the pattern Function(args, body) fails with a MatchError.
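To see the difference in tree shapes outside a macro, here is a small sketch using Scala 2 runtime reflection (it assumes the scala-reflect library is on the classpath; the object name ArgShape is hypothetical). Trees built with q"..." at runtime are untyped, but their top-level shape matches what the macro receives: a lambda literal is a Function, while a bare name is an Ident.

```scala
import scala.reflect.runtime.universe._

object ArgShape {
  // Describes the top-level node of an argument tree, i.e. what a macro
  // would see before trying to match it against Function(args, body).
  def describe(arg: Tree): String = arg match {
    case Function(params, _) =>
      s"Function(${params.map(_.name.toString).mkString(", ")})"
    case Ident(name) =>
      s"Ident($name)"
    case other =>
      other.getClass.getSimpleName
  }
}
```

For example, ArgShape.describe(q"(e1: Int) => e1 == 3") gives "Function(e1)", whereas ArgShape.describe(q"f") gives "Ident(f)", which is the case the original macro cannot destructure.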

Alternatively, you can store the AST of the lambda in f, rather than the lambda itself:

val f = q"(e1: Expr[Int]) => e1 === 3"

in Scala 2,

val f = '{ (e1: Expr[Int]) => e1 === 3 }

in Scala 3, and then make getLambdaArgNames a normal function.
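For the Scala 2 variant, a minimal sketch of such a normal (non-macro) function might look like the following. It assumes runtime reflection (scala-reflect on the classpath) so that q"..." can be used outside a macro; the object LambdaArgNames is hypothetical. Because the quasiquote tree is untyped, the names inside it (such as Expr or ===) are only parsed, not resolved.

```scala
import scala.reflect.runtime.universe._

object LambdaArgNames {
  // A plain function: the lambda's AST is passed in as an ordinary value,
  // so no macro expansion is involved.
  def getLambdaArgNames(f: Tree): String = f match {
    case Function(params, _) =>
      params.map(_.name.toString).mkString(", ")
    case other =>
      throw new IllegalArgumentException(s"expected a lambda tree, got: ${showRaw(other)}")
  }
}
```

Usage: with val f = q"(e1: Expr[Int]) => e1 === 3", LambdaArgNames.getLambdaArgNames(f) returns "e1". The trade-off is that f is now a tree, not a callable function; you would need to compile or evaluate it separately to actually run the lambda.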

Alexey Romanov
  • So if I understand correctly, there is no way to get the source of the lambda directly from the lambda itself? – alifirat Mar 17 '22 at 15:31
  • Yes. The information is simply not stored anywhere after compilation (there are languages which do, e.g. JavaScript, but not Scala). _Maybe_ for this specific case you could use reflection instead, but it depends on compilation details which I don't know at the moment and could easily change between Scala versions. – Alexey Romanov Mar 18 '22 at 20:34

Try an approach with a Traverser:

import scala.reflect.macros.blackbox

def getLambdaArgNamesImpl[A, B](c: blackbox.Context)(f: c.Expr[A => B]): c.Expr[String] = {
  import c.universe._

  val arguments = f.tree match {
    // The lambda literal was passed directly to the macro
    case Function(args, _) => args

    // Otherwise (e.g. an identifier), look for the definition it refers to
    case _ =>
      var rhs: Option[Tree] = None

      val traverser = new Traverser {
        override def traverse(tree: Tree): Unit = {
          tree match {
            // The name in this pattern is literal, so only a val named `f` can match;
            // the guard then checks that it is the same symbol as the argument
            case q"$_ val f: $_ = $expr"
              if tree.symbol == f.tree.symbol ||
                (tree.symbol.isTerm && tree.symbol.asTerm.getter == f.tree.symbol) =>
              rhs = Some(expr)
            case _ => super.traverse(tree)
          }
        }
      }

      // Search every compilation unit of the current compiler run
      c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))

      rhs match {
        case Some(Function(args, _)) => args
        case _ => c.abort(c.enclosingPosition, "can't find definition of val f")
      }
  }

  val names = arguments.map(_.name)
  val argNames = names.mkString(", ")
  val constant = Literal(Constant(argNames))
  c.Expr[String](q"$constant")
}

The case tree.symbol == f.tree.symbol matches when f is a local variable:

class TestSomething extends AnyFreeSpec with Matchers {
  "test" in {
    val f = (e1: Expr[Int]) => e1 === 3
    ...

The case tree.symbol.asTerm.getter == f.tree.symbol matches when f is a field of a class (the argument then refers to the field's getter rather than to the field itself):

class TestSomething extends AnyFreeSpec with Matchers {
  val f = (e1: Expr[Int]) => e1 === 3
  "test" in {
    ...
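The same Traverser idea can be tried outside a macro on an untyped tree. The sketch below is a simplification, not the answer's method: untyped trees carry no symbols, so it matches the val by name instead of by symbol. It assumes scala-reflect on the classpath, and FindLambdaByName and lambdaArgNames are hypothetical names.

```scala
import scala.reflect.runtime.universe._

object FindLambdaByName {
  // Searches `root` for `val <name> = <lambda>` and returns the lambda's
  // parameter names. Matching is by name, since untyped trees have no symbols.
  // If several vals match, the last one visited wins.
  def lambdaArgNames(root: Tree, name: String): Option[String] = {
    var result: Option[String] = None
    val traverser = new Traverser {
      override def traverse(tree: Tree): Unit = tree match {
        case ValDef(_, TermName(`name`), _, Function(params, _)) =>
          result = Some(params.map(_.name.toString).mkString(", "))
        case _ =>
          super.traverse(tree) // keep descending into child trees
      }
    }
    traverser.traverse(root)
    result
  }
}
```

For example, given q"{ val f = (e1: Int, e2: Int) => e1 + e2; println(f) }", lambdaArgNames(tree, "f") returns Some("e1, e2"). The macro version compares symbols instead, which is what lets it tell apart two different vals that happen to share the name f.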

See also:
  • Def Macro, pass parameter from a value
  • Creating a method definition tree from a method symbol and a body
  • Scala macro how to convert a MethodSymbol to DefDef with parameter default values?
  • How to get the runtime value of parameter passed to a Scala macro?

Dmytro Mitin